diff --git a/lama_cleaner/const.py b/lama_cleaner/const.py
index b506349..a28beba 100644
--- a/lama_cleaner/const.py
+++ b/lama_cleaner/const.py
@@ -4,6 +4,11 @@ from enum import Enum
from pydantic import BaseModel
+DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
+DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
+DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
+DIFFUSERS_SDXL_INPAINT_CLASS_NAME = "StableDiffusionXLInpaintPipeline"
+
MPS_UNSUPPORT_MODELS = [
"lama",
"ldm",
@@ -15,22 +20,8 @@ MPS_UNSUPPORT_MODELS = [
]
DEFAULT_MODEL = "lama"
-AVAILABLE_MODELS = [
- "lama",
- "ldm",
- "zits",
- "mat",
- "fcf",
- "manga",
- "cv2",
-]
-DIFFUSERS_MODEL_FP16_REVERSION = [
- "runwayml/stable-diffusion-inpainting",
- "Sanster/anything-4.0-inpainting",
- "Sanster/Realistic_Vision_V1.4-inpainting",
- "stabilityai/stable-diffusion-2-inpainting",
- "timbrooks/instruct-pix2pix",
-]
+AVAILABLE_MODELS = ["lama", "ldm", "zits", "mat", "fcf", "manga", "cv2", "migan"]
+
AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
DEFAULT_DEVICE = "cuda"
diff --git a/lama_cleaner/download.py b/lama_cleaner/download.py
index 66b14f9..6e90f4c 100644
--- a/lama_cleaner/download.py
+++ b/lama_cleaner/download.py
@@ -5,23 +5,23 @@ from typing import List
from loguru import logger
from pathlib import Path
-from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION, DEFAULT_MODEL_DIR
-from lama_cleaner.runtime import setup_model_dir
-from lama_cleaner.schema import (
- ModelInfo,
- ModelType,
- DIFFUSERS_SD_INPAINT_CLASS_NAME,
- DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
+from lama_cleaner.const import (
+ DEFAULT_MODEL_DIR,
DIFFUSERS_SD_CLASS_NAME,
+ DIFFUSERS_SD_INPAINT_CLASS_NAME,
DIFFUSERS_SDXL_CLASS_NAME,
+ DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
)
+from lama_cleaner.model.utils import handle_from_pretrained_exceptions
+from lama_cleaner.model_info import ModelInfo, ModelType
+from lama_cleaner.runtime import setup_model_dir
def cli_download_model(model: str, model_dir: Path):
setup_model_dir(model_dir)
from lama_cleaner.model import models
- if model in models:
+ if model in models and models[model].is_erase_model:
logger.info(f"Downloading {model}...")
models[model].download()
logger.info(f"Done.")
@@ -29,9 +29,10 @@ def cli_download_model(model: str, model_dir: Path):
logger.info(f"Downloading model from Huggingface: {model}")
from diffusers import DiffusionPipeline
- downloaded_path = DiffusionPipeline.download(
+ downloaded_path = handle_from_pretrained_exceptions(
+ DiffusionPipeline.download,
pretrained_model_name=model,
- variant="fp16" if model in DIFFUSERS_MODEL_FP16_REVERSION else "main",
+ variant="fp16",
resume_download=True,
)
logger.info(f"Done. Downloaded to {downloaded_path}")
@@ -43,21 +44,33 @@ def folder_name_to_show_name(name: str) -> str:
def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
cache_dir = Path(cache_dir)
+ stable_diffusion_dir = cache_dir / "stable_diffusion"
+ stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
# logger.info(f"Scanning single file sd/sdxl models in {cache_dir}")
res = []
- for it in cache_dir.glob(f"*.*"):
+ for it in stable_diffusion_dir.glob(f"*.*"):
if it.suffix not in [".safetensors", ".ckpt"]:
continue
if "inpaint" in str(it).lower():
- if "sdxl" in str(it).lower():
- model_type = ModelType.DIFFUSERS_SDXL_INPAINT
- else:
- model_type = ModelType.DIFFUSERS_SD_INPAINT
+ model_type = ModelType.DIFFUSERS_SD_INPAINT
else:
- if "sdxl" in str(it).lower():
- model_type = ModelType.DIFFUSERS_SDXL
- else:
- model_type = ModelType.DIFFUSERS_SD
+ model_type = ModelType.DIFFUSERS_SD
+ res.append(
+ ModelInfo(
+ name=it.name,
+ path=str(it.absolute()),
+ model_type=model_type,
+ is_single_file_diffusers=True,
+ )
+ )
+
+ for it in stable_diffusion_xl_dir.glob(f"*.*"):
+ if it.suffix not in [".safetensors", ".ckpt"]:
+ continue
+ if "inpaint" in str(it).lower():
+ model_type = ModelType.DIFFUSERS_SDXL_INPAINT
+ else:
+ model_type = ModelType.DIFFUSERS_SDXL
res.append(
ModelInfo(
name=it.name,
@@ -104,8 +117,9 @@ def scan_models() -> List[ModelInfo]:
name = folder_name_to_show_name(it.parent.parent.parent.name)
if name in diffusers_model_names:
continue
-
- if _class_name == DIFFUSERS_SD_CLASS_NAME:
+ if "PowerPaint" in name:
+ model_type = ModelType.DIFFUSERS_OTHER
+ elif _class_name == DIFFUSERS_SD_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SD
elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SD_INPAINT
diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py
index 1c12128..f48cecf 100644
--- a/lama_cleaner/helper.py
+++ b/lama_cleaner/helper.py
@@ -290,3 +290,7 @@ def only_keep_largest_contour(mask: np.ndarray) -> List[np.ndarray]:
return cv2.drawContours(new_mask, contours, max_index, 255, -1)
else:
return mask
+
+
+def is_mac():
+ return sys.platform == "darwin"
diff --git a/lama_cleaner/model/__init__.py b/lama_cleaner/model/__init__.py
index 1892ab7..473cb99 100644
--- a/lama_cleaner/model/__init__.py
+++ b/lama_cleaner/model/__init__.py
@@ -9,6 +9,7 @@ from .mat import MAT
from .mi_gan import MIGAN
from .opencv2 import OpenCV2
from .paint_by_example import PaintByExample
+from .power_paint.power_paint import PowerPaint
from .sd import SD15, SD2, Anything4, RealisticVision14, SD
from .sdxl import SDXL
from .zits import ZITS
@@ -30,4 +31,5 @@ models = {
InstructPix2Pix.name: InstructPix2Pix,
Kandinsky22.name: Kandinsky22,
SDXL.name: SDXL,
+ PowerPaint.name: PowerPaint,
}
diff --git a/lama_cleaner/model/base.py b/lama_cleaner/model/base.py
index b4b43ba..dd65d55 100644
--- a/lama_cleaner/model/base.py
+++ b/lama_cleaner/model/base.py
@@ -14,7 +14,7 @@ from lama_cleaner.helper import (
)
from lama_cleaner.model.helper.g_diffuser_bot import expand_image
from lama_cleaner.model.utils import get_scheduler
-from lama_cleaner.schema import Config, HDStrategy, SDSampler, ModelInfo
+from lama_cleaner.schema import Config, HDStrategy, SDSampler
class InpaintModel:
@@ -271,7 +271,7 @@ class InpaintModel:
class DiffusionInpaintModel(InpaintModel):
def __init__(self, device, **kwargs):
- self.model_info: ModelInfo = kwargs["model_info"]
+ self.model_info = kwargs["model_info"]
self.model_id_or_path = self.model_info.path
super().__init__(device, **kwargs)
diff --git a/lama_cleaner/model/controlnet.py b/lama_cleaner/model/controlnet.py
index 29591f8..b7288f8 100644
--- a/lama_cleaner/model/controlnet.py
+++ b/lama_cleaner/model/controlnet.py
@@ -5,7 +5,6 @@ import torch
from diffusers import ControlNetModel, DiffusionPipeline
from loguru import logger
-from lama_cleaner.const import DIFFUSERS_MODEL_FP16_REVERSION
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.helper.controlnet_preprocess import (
make_canny_control_image,
@@ -14,8 +13,8 @@ from lama_cleaner.model.helper.controlnet_preprocess import (
make_inpaint_control_image,
)
from lama_cleaner.model.helper.cpu_text_encoder import CPUTextEncoderWrapper
-from lama_cleaner.model.utils import get_scheduler
-from lama_cleaner.schema import Config, ModelInfo, ModelType
+from lama_cleaner.model.utils import get_scheduler, handle_from_pretrained_exceptions
+from lama_cleaner.schema import Config, ModelType
class ControlNet(DiffusionInpaintModel):
@@ -39,11 +38,11 @@ class ControlNet(DiffusionInpaintModel):
def init_model(self, device: torch.device, **kwargs):
fp16 = not kwargs.get("no_half", False)
- model_info: ModelInfo = kwargs["model_info"]
- sd_controlnet_method = kwargs["sd_controlnet_method"]
+ model_info = kwargs["model_info"]
+ controlnet_method = kwargs["controlnet_method"]
self.model_info = model_info
- self.sd_controlnet_method = sd_controlnet_method
+ self.controlnet_method = controlnet_method
model_kwargs = {}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
@@ -76,7 +75,8 @@ class ControlNet(DiffusionInpaintModel):
)
controlnet = ControlNetModel.from_pretrained(
- sd_controlnet_method, torch_dtype=torch_dtype, resume_download=True
+ pretrained_model_name_or_path=controlnet_method,
+ resume_download=True,
)
if model_info.is_single_file_diffusers:
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
@@ -88,17 +88,12 @@ class ControlNet(DiffusionInpaintModel):
model_info.path, controlnet=controlnet, **model_kwargs
).to(torch_dtype)
else:
- self.model = PipeClass.from_pretrained(
- model_info.path,
+ self.model = handle_from_pretrained_exceptions(
+ PipeClass.from_pretrained,
+ pretrained_model_name_or_path=model_info.path,
controlnet=controlnet,
- revision="fp16"
- if (
- model_info.path in DIFFUSERS_MODEL_FP16_REVERSION
- and use_gpu
- and fp16
- )
- else "main",
- torch_dtype=torch_dtype,
+ variant="fp16",
+ dtype=torch_dtype,
**model_kwargs,
)
@@ -116,23 +111,23 @@ class ControlNet(DiffusionInpaintModel):
self.callback = kwargs.pop("callback", None)
def switch_controlnet_method(self, new_method: str):
- self.sd_controlnet_method = new_method
+ self.controlnet_method = new_method
controlnet = ControlNetModel.from_pretrained(
new_method, torch_dtype=self.torch_dtype, resume_download=True
).to(self.model.device)
self.model.controlnet = controlnet
def _get_control_image(self, image, mask):
- if "canny" in self.sd_controlnet_method:
+ if "canny" in self.controlnet_method:
control_image = make_canny_control_image(image)
- elif "openpose" in self.sd_controlnet_method:
+ elif "openpose" in self.controlnet_method:
control_image = make_openpose_control_image(image)
- elif "depth" in self.sd_controlnet_method:
+ elif "depth" in self.controlnet_method:
control_image = make_depth_control_image(image)
- elif "inpaint" in self.sd_controlnet_method:
+ elif "inpaint" in self.controlnet_method:
control_image = make_inpaint_control_image(image, mask)
else:
- raise NotImplementedError(f"{self.sd_controlnet_method} not implemented")
+ raise NotImplementedError(f"{self.controlnet_method} not implemented")
return control_image
def forward(self, image, mask, config: Config):
diff --git a/lama_cleaner/model/kandinsky.py b/lama_cleaner/model/kandinsky.py
index 783a01b..0645af7 100644
--- a/lama_cleaner/model/kandinsky.py
+++ b/lama_cleaner/model/kandinsky.py
@@ -24,7 +24,7 @@ class Kandinsky(DiffusionInpaintModel):
}
self.model = AutoPipelineForInpainting.from_pretrained(
- self.model_id_or_path, **model_kwargs
+ self.name, **model_kwargs
).to(device)
self.callback = kwargs.pop("callback", None)
@@ -66,4 +66,3 @@ class Kandinsky(DiffusionInpaintModel):
class Kandinsky22(Kandinsky):
name = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
- model_id_or_path = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
diff --git a/lama_cleaner/model/mi_gan.py b/lama_cleaner/model/mi_gan.py
index 3e3f200..d8ec0fa 100644
--- a/lama_cleaner/model/mi_gan.py
+++ b/lama_cleaner/model/mi_gan.py
@@ -16,7 +16,7 @@ from lama_cleaner.model.base import InpaintModel
MIGAN_MODEL_URL = os.environ.get(
"MIGAN_MODEL_URL",
- "/Users/cwq/code/github/MI-GAN/exported_models/migan_places512/models/migan_traced.pt",
+ "https://github.com/Sanster/models/releases/download/migan/migan_traced.pt",
)
MIGAN_MODEL_MD5 = os.environ.get("MIGAN_MODEL_MD5", "76eb3b1a71c400ee3290524f7a11b89c")
diff --git a/lama_cleaner/model/paint_by_example.py b/lama_cleaner/model/paint_by_example.py
index 80b9745..07d3842 100644
--- a/lama_cleaner/model/paint_by_example.py
+++ b/lama_cleaner/model/paint_by_example.py
@@ -28,7 +28,7 @@ class PaintByExample(DiffusionInpaintModel):
)
self.model = DiffusionPipeline.from_pretrained(
- "Fantasy-Studio/Paint-by-Example", torch_dtype=torch_dtype, **model_kwargs
+ self.name, torch_dtype=torch_dtype, **model_kwargs
)
# TODO: gpu_id
diff --git a/lama_cleaner/model/pipeline/__init__.py b/lama_cleaner/model/pipeline/__init__.py
deleted file mode 100644
index 9056bc6..0000000
--- a/lama_cleaner/model/pipeline/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-from .pipeline_stable_diffusion_controlnet_inpaint import (
- StableDiffusionControlNetInpaintPipeline,
-)
diff --git a/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py b/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py
deleted file mode 100644
index f65e95d..0000000
--- a/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py
+++ /dev/null
@@ -1,638 +0,0 @@
-# Copyright 2023 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import gc
-from typing import Union, List, Optional, Callable, Dict, Any
-
-# Copy from https://github.com/mikonvergence/ControlNetInpaint/blob/main/src/pipeline_stable_diffusion_controlnet_inpaint.py
-
-import torch
-import PIL.Image
-
-from diffusers.pipelines.controlnet.pipeline_controlnet import *
-from diffusers.utils import replace_example_docstring
-
-EXAMPLE_DOC_STRING = """
- Examples:
- ```py
- >>> # !pip install opencv-python transformers accelerate
- >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
- >>> from diffusers.utils import load_image
- >>> import numpy as np
- >>> import torch
-
- >>> import cv2
- >>> from PIL import Image
- >>> # download an image
- >>> image = load_image(
- ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
- ... )
- >>> image = np.array(image)
- >>> mask_image = load_image(
- ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
- ... )
- >>> mask_image = np.array(mask_image)
- >>> # get canny image
- >>> canny_image = cv2.Canny(image, 100, 200)
- >>> canny_image = canny_image[:, :, None]
- >>> canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
- >>> canny_image = Image.fromarray(canny_image)
-
- >>> # load control net and stable diffusion v1-5
- >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
- >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
- ... "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
- ... )
-
- >>> # speed up diffusion process with faster scheduler and memory optimization
- >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
- >>> # remove following line if xformers is not installed
- >>> pipe.enable_xformers_memory_efficient_attention()
-
- >>> pipe.enable_model_cpu_offload()
-
- >>> # generate image
- >>> generator = torch.manual_seed(0)
- >>> image = pipe(
- ... "futuristic-looking doggo",
- ... num_inference_steps=20,
- ... generator=generator,
- ... image=image,
- ... control_image=canny_image,
- ... mask_image=mask_image
- ... ).images[0]
- ```
-"""
-
-
-def prepare_mask_and_masked_image(image, mask):
- """
- Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
- converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
- ``image`` and ``1`` for the ``mask``.
- The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
- binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
- Args:
- image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
- It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
- ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
- mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
- It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
- ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
- Raises:
- ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
- should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
- TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
- (ot the other way around).
- Returns:
- tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
- dimensions: ``batch x channels x height x width``.
- """
- if isinstance(image, torch.Tensor):
- if not isinstance(mask, torch.Tensor):
- raise TypeError(
- f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not"
- )
-
- # Batch single image
- if image.ndim == 3:
- assert (
- image.shape[0] == 3
- ), "Image outside a batch should be of shape (3, H, W)"
- image = image.unsqueeze(0)
-
- # Batch and add channel dim for single mask
- if mask.ndim == 2:
- mask = mask.unsqueeze(0).unsqueeze(0)
-
- # Batch single mask or add channel dim
- if mask.ndim == 3:
- # Single batched mask, no channel dim or single mask not batched but channel dim
- if mask.shape[0] == 1:
- mask = mask.unsqueeze(0)
-
- # Batched masks no channel dim
- else:
- mask = mask.unsqueeze(1)
-
- assert (
- image.ndim == 4 and mask.ndim == 4
- ), "Image and Mask must have 4 dimensions"
- assert (
- image.shape[-2:] == mask.shape[-2:]
- ), "Image and Mask must have the same spatial dimensions"
- assert (
- image.shape[0] == mask.shape[0]
- ), "Image and Mask must have the same batch size"
-
- # Check image is in [-1, 1]
- if image.min() < -1 or image.max() > 1:
- raise ValueError("Image should be in [-1, 1] range")
-
- # Check mask is in [0, 1]
- if mask.min() < 0 or mask.max() > 1:
- raise ValueError("Mask should be in [0, 1] range")
-
- # Binarize mask
- mask[mask < 0.5] = 0
- mask[mask >= 0.5] = 1
-
- # Image as float32
- image = image.to(dtype=torch.float32)
- elif isinstance(mask, torch.Tensor):
- raise TypeError(
- f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not"
- )
- else:
- # preprocess image
- if isinstance(image, (PIL.Image.Image, np.ndarray)):
- image = [image]
-
- if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
- image = [np.array(i.convert("RGB"))[None, :] for i in image]
- image = np.concatenate(image, axis=0)
- elif isinstance(image, list) and isinstance(image[0], np.ndarray):
- image = np.concatenate([i[None, :] for i in image], axis=0)
-
- image = image.transpose(0, 3, 1, 2)
- image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
-
- # preprocess mask
- if isinstance(mask, (PIL.Image.Image, np.ndarray)):
- mask = [mask]
-
- if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
- mask = np.concatenate(
- [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
- )
- mask = mask.astype(np.float32) / 255.0
- elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
- mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
-
- mask[mask < 0.5] = 0
- mask[mask >= 0.5] = 1
- mask = torch.from_numpy(mask)
-
- masked_image = image * (mask < 0.5)
-
- return mask, masked_image
-
-
-class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline):
- r"""
- Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance.
-
- This model inherits from [`StableDiffusionControlNetPipeline`]. Check the superclass documentation for the generic methods the
- library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
-
- Args:
- vae ([`AutoencoderKL`]):
- Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
- text_encoder ([`CLIPTextModel`]):
- Frozen text-encoder. Stable Diffusion uses the text portion of
- [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
- the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
- tokenizer (`CLIPTokenizer`):
- Tokenizer of class
- [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
- unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
- controlnet ([`ControlNetModel`]):
- Provides additional conditioning to the unet during the denoising process
- scheduler ([`SchedulerMixin`]):
- A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
- [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
- safety_checker ([`StableDiffusionSafetyChecker`]):
- Classification module that estimates whether generated images could be considered offensive or harmful.
- Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
- feature_extractor ([`CLIPFeatureExtractor`]):
- Model that extracts features from generated images to be used as inputs for the `safety_checker`.
- """
-
- @classmethod
- def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
- from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
- download_from_original_stable_diffusion_ckpt,
- )
-
- controlnet = kwargs.pop("controlnet", None)
-
- pipe = download_from_original_stable_diffusion_ckpt(
- pretrained_model_link_or_path,
- num_in_channels=9,
- from_safetensors=pretrained_model_link_or_path.endswith("safetensors"),
- device="cpu",
- load_safety_checker=False,
- )
-
- inpaint_pipe = cls(
- vae=pipe.vae,
- text_encoder=pipe.text_encoder,
- tokenizer=pipe.tokenizer,
- unet=pipe.unet,
- controlnet=controlnet,
- scheduler=pipe.scheduler,
- safety_checker=None,
- feature_extractor=None,
- requires_safety_checker=False,
- )
-
- del pipe
- gc.collect()
- return inpaint_pipe
-
- def prepare_mask_latents(
- self,
- mask,
- masked_image,
- batch_size,
- height,
- width,
- dtype,
- device,
- generator,
- do_classifier_free_guidance,
- ):
- # resize the mask to latents shape as we concatenate the mask to the latents
- # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
- # and half precision
- mask = torch.nn.functional.interpolate(
- mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
- )
- mask = mask.to(device=device, dtype=dtype)
-
- masked_image = masked_image.to(device=device, dtype=dtype)
-
- # encode the mask image into latents space so we can concatenate it to the latents
- if isinstance(generator, list):
- masked_image_latents = [
- self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(
- generator=generator[i]
- )
- for i in range(batch_size)
- ]
- masked_image_latents = torch.cat(masked_image_latents, dim=0)
- else:
- masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(
- generator=generator
- )
- masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
-
- # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
- if mask.shape[0] < batch_size:
- if not batch_size % mask.shape[0] == 0:
- raise ValueError(
- "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
- f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
- " of masks that you pass is divisible by the total requested batch size."
- )
- mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
- if masked_image_latents.shape[0] < batch_size:
- if not batch_size % masked_image_latents.shape[0] == 0:
- raise ValueError(
- "The passed images and the required batch size don't match. Images are supposed to be duplicated"
- f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
- " Make sure the number of images that you pass is divisible by the total requested batch size."
- )
- masked_image_latents = masked_image_latents.repeat(
- batch_size // masked_image_latents.shape[0], 1, 1, 1
- )
-
- mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
- masked_image_latents = (
- torch.cat([masked_image_latents] * 2)
- if do_classifier_free_guidance
- else masked_image_latents
- )
-
- # aligning device to prevent device errors when concating it with the latent model input
- masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
- return mask, masked_image_latents
-
- def _default_height_width(self, height, width, image):
- if isinstance(image, list):
- image = image[0]
-
- if height is None:
- if isinstance(image, PIL.Image.Image):
- height = image.height
- elif isinstance(image, torch.Tensor):
- height = image.shape[3]
-
- height = (height // 8) * 8 # round down to nearest multiple of 8
-
- if width is None:
- if isinstance(image, PIL.Image.Image):
- width = image.width
- elif isinstance(image, torch.Tensor):
- width = image.shape[2]
-
- width = (width // 8) * 8 # round down to nearest multiple of 8
-
- return height, width
-
- @torch.no_grad()
- @replace_example_docstring(EXAMPLE_DOC_STRING)
- def __call__(
- self,
- prompt: Union[str, List[str]] = None,
- image: Union[torch.FloatTensor, PIL.Image.Image] = None,
- control_image: Union[
- torch.FloatTensor,
- PIL.Image.Image,
- List[torch.FloatTensor],
- List[PIL.Image.Image],
- ] = None,
- mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
- height: Optional[int] = None,
- width: Optional[int] = None,
- num_inference_steps: int = 50,
- guidance_scale: float = 7.5,
- negative_prompt: Optional[Union[str, List[str]]] = None,
- num_images_per_prompt: Optional[int] = 1,
- eta: float = 0.0,
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
- latents: Optional[torch.FloatTensor] = None,
- prompt_embeds: Optional[torch.FloatTensor] = None,
- negative_prompt_embeds: Optional[torch.FloatTensor] = None,
- output_type: Optional[str] = "pil",
- return_dict: bool = True,
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
- callback_steps: int = 1,
- cross_attention_kwargs: Optional[Dict[str, Any]] = None,
- controlnet_conditioning_scale: float = 1.0,
- ):
- r"""
- Function invoked when calling the pipeline for generation.
- Args:
- prompt (`str` or `List[str]`, *optional*):
- The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
- instead.
- image (`PIL.Image.Image`):
- `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
- be masked out with `mask_image` and repainted according to `prompt`.
- control_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
- The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
- the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can
- also be accepted as an image. The control image is automatically resized to fit the output image.
- mask_image (`PIL.Image.Image`):
- `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
- repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
- to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
- instead of 3, so the expected shape would be `(B, H, W, 1)`.
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
- The height in pixels of the generated image.
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
- The width in pixels of the generated image.
- num_inference_steps (`int`, *optional*, defaults to 50):
- The number of denoising steps. More denoising steps usually lead to a higher quality image at the
- expense of slower inference.
- guidance_scale (`float`, *optional*, defaults to 7.5):
- Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
- `guidance_scale` is defined as `w` of equation 2. of [Imagen
- Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
- 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
- usually at the expense of lower image quality.
- negative_prompt (`str` or `List[str]`, *optional*):
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
- `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
- Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
- num_images_per_prompt (`int`, *optional*, defaults to 1):
- The number of images to generate per prompt.
- eta (`float`, *optional*, defaults to 0.0):
- Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
- [`schedulers.DDIMScheduler`], will be ignored for others.
- generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
- One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
- to make generation deterministic.
- latents (`torch.FloatTensor`, *optional*):
- Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
- generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
- tensor will ge generated by sampling using the supplied random `generator`.
- prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
- provided, text embeddings will be generated from `prompt` input argument.
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
- argument.
- output_type (`str`, *optional*, defaults to `"pil"`):
- The output format of the generate image. Choose between
- [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
- return_dict (`bool`, *optional*, defaults to `True`):
- Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
- plain tuple.
- callback (`Callable`, *optional*):
- A function that will be called every `callback_steps` steps during inference. The function will be
- called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
- callback_steps (`int`, *optional*, defaults to 1):
- The frequency at which the `callback` function will be called. If not specified, the callback will be
- called at every step.
- cross_attention_kwargs (`dict`, *optional*):
- A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
- `self.processor` in
- [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
- controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
- The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
- to the residual in the original unet.
- Examples:
- Returns:
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
- [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
- When returning a tuple, the first element is a list with the generated images, and the second element is a
- list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
- (nsfw) content, according to the `safety_checker`.
- """
- # 0. Default height and width to unet
- height, width = self._default_height_width(height, width, control_image)
-
- # 1. Check inputs. Raise error if not correct
- self.check_inputs(
- prompt=prompt,
- image=control_image,
- callback_steps=callback_steps,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- )
-
- # 2. Define call parameters
- if prompt is not None and isinstance(prompt, str):
- batch_size = 1
- elif prompt is not None and isinstance(prompt, list):
- batch_size = len(prompt)
- else:
- batch_size = prompt_embeds.shape[0]
-
- device = self._execution_device
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
- # corresponds to doing no classifier free guidance.
- do_classifier_free_guidance = guidance_scale > 1.0
-
- # 3. Encode input prompt
- prompt_embeds = self._encode_prompt(
- prompt,
- device,
- num_images_per_prompt,
- do_classifier_free_guidance,
- negative_prompt,
- prompt_embeds=prompt_embeds,
- negative_prompt_embeds=negative_prompt_embeds,
- )
-
- # 4. Prepare image
- control_image = self.prepare_image(
- control_image,
- width,
- height,
- batch_size * num_images_per_prompt,
- num_images_per_prompt,
- device,
- self.controlnet.dtype,
- )
-
- if do_classifier_free_guidance:
- control_image = torch.cat([control_image] * 2)
-
- # 5. Prepare timesteps
- self.scheduler.set_timesteps(num_inference_steps, device=device)
- timesteps = self.scheduler.timesteps
-
- # 6. Prepare latent variables
- num_channels_latents = self.controlnet.config.in_channels
- latents = self.prepare_latents(
- batch_size * num_images_per_prompt,
- num_channels_latents,
- height,
- width,
- prompt_embeds.dtype,
- device,
- generator,
- latents,
- )
-
- # EXTRA: prepare mask latents
- mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
- mask, masked_image_latents = self.prepare_mask_latents(
- mask,
- masked_image,
- batch_size * num_images_per_prompt,
- height,
- width,
- prompt_embeds.dtype,
- device,
- generator,
- do_classifier_free_guidance,
- )
-
- # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-
- # 8. Denoising loop
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
- with self.progress_bar(total=num_inference_steps) as progress_bar:
- for i, t in enumerate(timesteps):
- # expand the latents if we are doing classifier free guidance
- latent_model_input = (
- torch.cat([latents] * 2) if do_classifier_free_guidance else latents
- )
- latent_model_input = self.scheduler.scale_model_input(
- latent_model_input, t
- )
-
- down_block_res_samples, mid_block_res_sample = self.controlnet(
- latent_model_input,
- t,
- encoder_hidden_states=prompt_embeds,
- controlnet_cond=control_image,
- return_dict=False,
- )
-
- down_block_res_samples = [
- down_block_res_sample * controlnet_conditioning_scale
- for down_block_res_sample in down_block_res_samples
- ]
- mid_block_res_sample *= controlnet_conditioning_scale
-
- # predict the noise residual
- latent_model_input = torch.cat(
- [latent_model_input, mask, masked_image_latents], dim=1
- )
- noise_pred = self.unet(
- latent_model_input,
- t,
- encoder_hidden_states=prompt_embeds,
- cross_attention_kwargs=cross_attention_kwargs,
- down_block_additional_residuals=down_block_res_samples,
- mid_block_additional_residual=mid_block_res_sample,
- ).sample
-
- # perform guidance
- if do_classifier_free_guidance:
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + guidance_scale * (
- noise_pred_text - noise_pred_uncond
- )
-
- # compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(
- noise_pred, t, latents, **extra_step_kwargs
- ).prev_sample
-
- # call the callback, if provided
- if i == len(timesteps) - 1 or (
- (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
- ):
- progress_bar.update()
- if callback is not None and i % callback_steps == 0:
- callback(i, t, latents)
-
- # If we do sequential model offloading, let's offload unet and controlnet
- # manually for max memory savings
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
- self.unet.to("cpu")
- self.controlnet.to("cpu")
- torch.cuda.empty_cache()
-
- if output_type == "latent":
- image = latents
- has_nsfw_concept = None
- elif output_type == "pil":
- # 8. Post-processing
- image = self.decode_latents(latents)
-
- # 9. Run safety checker
- image, has_nsfw_concept = self.run_safety_checker(
- image, device, prompt_embeds.dtype
- )
-
- # 10. Convert to PIL
- image = self.numpy_to_pil(image)
- else:
- # 8. Post-processing
- image = self.decode_latents(latents)
-
- # 9. Run safety checker
- image, has_nsfw_concept = self.run_safety_checker(
- image, device, prompt_embeds.dtype
- )
-
- # Offload last model to CPU
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
- self.final_offload_hook.offload()
-
- if not return_dict:
- return (image, has_nsfw_concept)
-
- return StableDiffusionPipelineOutput(
- images=image, nsfw_content_detected=has_nsfw_concept
- )
diff --git a/lama_cleaner/model/power_paint/__init__.py b/lama_cleaner/model/power_paint/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/lama_cleaner/model/power_paint/pipeline_powerpaint.py b/lama_cleaner/model/power_paint/pipeline_powerpaint.py
new file mode 100644
index 0000000..9b7f8c5
--- /dev/null
+++ b/lama_cleaner/model/power_paint/pipeline_powerpaint.py
@@ -0,0 +1,1243 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+
+import numpy as np
+import PIL
+import torch
+from packaging import version
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import (
+ FromSingleFileMixin,
+ LoraLoaderMixin,
+ TextualInversionLoaderMixin,
+)
+from diffusers.models import (
+ AsymmetricAutoencoderKL,
+ AutoencoderKL,
+ UNet2DConditionModel,
+)
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ deprecate,
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+)
+from diffusers.utils.torch_utils import randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import (
+ StableDiffusionSafetyChecker,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def prepare_mask_and_masked_image(
+ image, mask, height, width, return_image: bool = False
+):
+ """
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+ ``image`` and ``1`` for the ``mask``.
+
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+            (or the other way around).
+
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
+ """
+
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ if mask is None:
+ raise ValueError("`mask_image` input cannot be undefined.")
+
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+ raise TypeError(
+ f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not"
+ )
+
+ # Batch single image
+ if image.ndim == 3:
+ assert (
+ image.shape[0] == 3
+ ), "Image outside a batch should be of shape (3, H, W)"
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Single batched mask, no channel dim or single mask not batched but channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ assert (
+ image.ndim == 4 and mask.ndim == 4
+ ), "Image and Mask must have 4 dimensions"
+ assert (
+ image.shape[-2:] == mask.shape[-2:]
+ ), "Image and Mask must have the same spatial dimensions"
+ assert (
+ image.shape[0] == mask.shape[0]
+ ), "Image and Mask must have the same batch size"
+
+ # Check image is in [-1, 1]
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+ raise TypeError(
+ f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not"
+ )
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+            # resize all images w.r.t. passed height and width
+ image = [
+ i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image
+ ]
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ # preprocess mask
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+ mask = [mask]
+
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+ mask = np.concatenate(
+ [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
+ )
+ mask = mask.astype(np.float32) / 255.0
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
+ masked_image = image * (mask < 0.5)
+
+ # n.b. ensure backwards compatibility as old function does not return image
+ if return_image:
+ return mask, masked_image, image
+
+ return mask, masked_image
+
+
+class StableDiffusionInpaintPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for text-guided image inpainting using Stable Diffusion.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+
+ Args:
+ vae ([`AutoencoderKL`, `AsymmetricAutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: Union[AutoencoderKL, AsymmetricAutoencoderKL],
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if (
+ hasattr(scheduler.config, "steps_offset")
+ and scheduler.config.steps_offset != 1
+ ):
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate(
+ "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
+ )
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if (
+ hasattr(scheduler.config, "skip_prk_steps")
+ and scheduler.config.skip_prk_steps is False
+ ):
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration"
+ " `skip_prk_steps`. `skip_prk_steps` should be set to True in the configuration file. Please make"
+ " sure to update the config accordingly as not setting `skip_prk_steps` in the config might lead to"
+ " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging Face"
+ " Hub, it would be very nice if you could open a Pull request for the"
+ " `scheduler/scheduler_config.json` file"
+ )
+ deprecate(
+ "skip_prk_steps not set",
+ "1.0.0",
+ deprecation_message,
+ standard_warn=False,
+ )
+ new_config = dict(scheduler.config)
+ new_config["skip_prk_steps"] = True
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ is_unet_version_less_0_9_0 = hasattr(
+ unet.config, "_diffusers_version"
+ ) and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse(
+ "0.9.0.dev0"
+ )
+ is_unet_sample_size_less_64 = (
+ hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ )
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+                " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate(
+ "sample_size<64", "1.0.0", deprecation_message, standard_warn=False
+ )
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ # Check shapes, assume num_channels_latents == 4, num_channels_mask == 1, num_channels_masked == 4
+ if unet.config.in_channels != 9:
+ logger.info(
+                f"You have loaded a UNet with {unet.config.in_channels} input channels, which differs from the 9 input channels expected for inpainting."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offload all models to CPU to reduce memory usage with a low impact on performance. Moves one whole model at a
+ time to the GPU when its `forward` method is called, and the model remains in GPU until the next model runs.
+ Memory savings are lower than using `enable_sequential_cpu_offload`, but performance is much better due to the
+ iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError(
+ "`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher."
+ )
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(
+ cpu_offloaded_model, device, prev_module_hook=hook
+ )
+
+ if self.safety_checker is not None:
+ _, hook = cpu_offload_with_hook(
+ self.safety_checker, device, prev_module_hook=hook
+ )
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ promptA,
+ promptB,
+ t,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_promptA=None,
+ negative_promptB=None,
+ t_nag=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ prompt = promptA
+ negative_prompt = negative_promptA
+
+ if promptA is not None and isinstance(promptA, str):
+ batch_size = 1
+ elif promptA is not None and isinstance(promptA, list):
+ batch_size = len(promptA)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+            # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ promptA = self.maybe_convert_prompt(promptA, self.tokenizer)
+
+ text_inputsA = self.tokenizer(
+ promptA,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_inputsB = self.tokenizer(
+ promptB,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_idsA = text_inputsA.input_ids
+ text_input_idsB = text_inputsB.input_ids
+ untruncated_ids = self.tokenizer(
+ promptA, padding="longest", return_tensors="pt"
+ ).input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_idsA.shape[
+ -1
+ ] and not torch.equal(text_input_idsA, untruncated_ids):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if (
+ hasattr(self.text_encoder.config, "use_attention_mask")
+ and self.text_encoder.config.use_attention_mask
+ ):
+ attention_mask = text_inputsA.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ # print("text_input_idsA: ",text_input_idsA)
+ # print("text_input_idsB: ",text_input_idsB)
+ # print('t: ',t)
+
+ prompt_embedsA = self.text_encoder(
+ text_input_idsA.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embedsA = prompt_embedsA[0]
+
+ prompt_embedsB = self.text_encoder(
+ text_input_idsB.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embedsB = prompt_embedsB[0]
+ prompt_embeds = prompt_embedsA * (t) + (1 - t) * prompt_embedsB
+ # print("prompt_embeds: ",prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(
+ bs_embed * num_images_per_prompt, seq_len, -1
+ )
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokensA: List[str]
+ uncond_tokensB: List[str]
+ if negative_prompt is None:
+ uncond_tokensA = [""] * batch_size
+ uncond_tokensB = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokensA = [negative_promptA]
+ uncond_tokensB = [negative_promptB]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokensA = negative_promptA
+ uncond_tokensB = negative_promptB
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokensA = self.maybe_convert_prompt(
+ uncond_tokensA, self.tokenizer
+ )
+ uncond_tokensB = self.maybe_convert_prompt(
+ uncond_tokensB, self.tokenizer
+ )
+
+ max_length = prompt_embeds.shape[1]
+ uncond_inputA = self.tokenizer(
+ uncond_tokensA,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ uncond_inputB = self.tokenizer(
+ uncond_tokensB,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if (
+ hasattr(self.text_encoder.config, "use_attention_mask")
+ and self.text_encoder.config.use_attention_mask
+ ):
+ attention_mask = uncond_inputA.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embedsA = self.text_encoder(
+ uncond_inputA.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embedsB = self.text_encoder(
+ uncond_inputB.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
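+ # blend the two negative prompt embeddings with weight t_nag (the `tradoff_nag` passed from __call__)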
+ negative_prompt_embeds = (
+ negative_prompt_embedsA[0] * (t_nag)
+ + (1 - t_nag) * negative_prompt_embedsB[0]
+ )
+
+ # negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(
+ dtype=prompt_embeds_dtype, device=device
+ )
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
+ 1, num_images_per_prompt, 1
+ )
+ negative_prompt_embeds = negative_prompt_embeds.view(
+ batch_size * num_images_per_prompt, seq_len, -1
+ )
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ # print("prompt_embeds: ",prompt_embeds)
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(
+ image, output_type="pil"
+ )
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(
+ feature_extractor_input, return_tensors="pt"
+ ).to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(
+ inspect.signature(self.scheduler.step).parameters.keys()
+ )
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(
+ inspect.signature(self.scheduler.step).parameters.keys()
+ )
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if strength < 0 or strength > 1:
+ raise ValueError(
+ f"The value of strength should in [0.0, 1.0] but is {strength}"
+ )
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(
+ f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
+ )
+
+ if (callback_steps is None) or (
+ callback_steps is not None
+ and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (
+ not isinstance(prompt, str) and not isinstance(prompt, list)
+ ):
+ raise ValueError(
+ f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
+ )
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(
+ self,
+ batch_size,
+ num_channels_latents,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents=None,
+ image=None,
+ timestep=None,
+ is_strength_max=True,
+ return_noise=False,
+ return_image_latents=False,
+ ):
+ shape = (
+ batch_size,
+ num_channels_latents,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if (image is None or timestep is None) and not is_strength_max:
+ raise ValueError(
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
+ "However, either the image or the noise timestep has not been provided."
+ )
+
+ if return_image_latents or (latents is None and not is_strength_max):
+ image = image.to(device=device, dtype=dtype)
+ image_latents = self._encode_vae_image(image=image, generator=generator)
+
+ if latents is None:
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
+ latents = (
+ noise
+ if is_strength_max
+ else self.scheduler.add_noise(image_latents, noise, timestep)
+ )
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
+ latents = (
+ latents * self.scheduler.init_noise_sigma
+ if is_strength_max
+ else latents
+ )
+ else:
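+ # user-supplied latents are treated as the initial noise and scaled by the scheduler's init sigma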
+ noise = latents.to(device)
+ latents = noise * self.scheduler.init_noise_sigma
+
+ outputs = (latents,)
+
+ if return_noise:
+ outputs += (noise,)
+
+ if return_image_latents:
+ outputs += (image_latents,)
+
+ return outputs
+
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
+ if isinstance(generator, list):
+ image_latents = [
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(
+ generator=generator[i]
+ )
+ for i in range(image.shape[0])
+ ]
+ image_latents = torch.cat(image_latents, dim=0)
+ else:
+ image_latents = self.vae.encode(image).latent_dist.sample(
+ generator=generator
+ )
+
+ image_latents = self.vae.config.scaling_factor * image_latents
+
+ return image_latents
+
+ def prepare_mask_latents(
+ self,
+ mask,
+ masked_image,
+ batch_size,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ ):
+ # resize the mask to latents shape as we concatenate the mask to the latents
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+ # and half precision
+ mask = torch.nn.functional.interpolate(
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+ )
+ mask = mask.to(device=device, dtype=dtype)
+
+ masked_image = masked_image.to(device=device, dtype=dtype)
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
+
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+ if mask.shape[0] < batch_size:
+ if not batch_size % mask.shape[0] == 0:
+ raise ValueError(
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+ " of masks that you pass is divisible by the total requested batch size."
+ )
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+ if masked_image_latents.shape[0] < batch_size:
+ if not batch_size % masked_image_latents.shape[0] == 0:
+ raise ValueError(
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
+ )
+ masked_image_latents = masked_image_latents.repeat(
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
+ )
+
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+ masked_image_latents = (
+ torch.cat([masked_image_latents] * 2)
+ if do_classifier_free_guidance
+ else masked_image_latents
+ )
+
+ # aligning device to prevent device errors when concating it with the latent model input
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+ return mask, masked_image_latents
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
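+ # e.g. strength=0.5 with 50 inference steps keeps only the last 25 timesteps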
+
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
+ @torch.no_grad()
+ def __call__(
+ self,
+ promptA: Union[str, List[str]] = None,
+ promptB: Union[str, List[str]] = None,
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ strength: float = 1.0,
+ tradoff: float = 1.0,
+ tradoff_nag: float = 1.0,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_promptA: Optional[Union[str, List[str]]] = None,
+ negative_promptB: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ task_class: Union[torch.Tensor, float, int] = None,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ promptA (`str` or `List[str]`, *optional*):
+ The first prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+ promptB (`str` or `List[str]`, *optional*):
+ The second prompt or prompts. The conditioning embeddings are a blend of the embeddings of `promptA`
+ and `promptB`, weighted by `tradoff` (and by `tradoff_nag` for the negative prompts).
+ image (`PIL.Image.Image`):
+ `Image` or tensor representing an image batch to be inpainted (which parts of the image to be masked
+ out with `mask_image` and repainted according to `prompt`).
+ mask_image (`PIL.Image.Image`):
+ `Image` or tensor representing an image batch to mask `image`. White pixels in the mask are repainted
+ while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel
+ (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the
+ expected shape would be `(B, H, W, 1)`.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ strength (`float`, *optional*, defaults to 1.0):
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
+ essentially ignores `image`.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference. This parameter is modulated by `strength`.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_promptA (`str` or `List[str]`, *optional*):
+ The first prompt or prompts to guide what to not include in image generation. If not defined, you need
+ to pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ negative_promptB (`str` or `List[str]`, *optional*):
+ The second negative prompt or prompts, blended with `negative_promptA` using `tradoff_nag`.
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that is called every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+ Examples:
+
+ ```py
+ >>> import PIL
+ >>> import requests
+ >>> import torch
+ >>> from io import BytesIO
+
+ >>> from diffusers import StableDiffusionInpaintPipeline
+
+
+ >>> def download_image(url):
+ ... response = requests.get(url)
+ ... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+
+
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+ >>> init_image = download_image(img_url).resize((512, 512))
+ >>> mask_image = download_image(mask_url).resize((512, 512))
+
+ >>> pipe = StableDiffusionInpaintPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
+ ... )
+ >>> pipe = pipe.to("cuda")
+
+ >>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+ >>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
+ ```
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
+ "not-safe-for-work" (nsfw) content.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+ prompt = promptA
+ negative_prompt = negative_promptA
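+ # as in _encode_prompt, promptA stands in for `prompt` when validating inputs and deriving the batch size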
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ height,
+ width,
+ strength,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None)
+ if cross_attention_kwargs is not None
+ else None
+ )
+ prompt_embeds = self._encode_prompt(
+ promptA,
+ promptB,
+ tradoff,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_promptA,
+ negative_promptB,
+ tradoff_nag,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ )
+
+ # 4. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(
+ num_inference_steps=num_inference_steps, strength=strength, device=device
+ )
+ # check that number of inference steps is not < 1 - as this doesn't make sense
+ if num_inference_steps < 1:
+ raise ValueError(
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+ )
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+ is_strength_max = strength == 1.0
+
+ # 5. Preprocess mask and image
+ mask, masked_image, init_image = prepare_mask_and_masked_image(
+ image, mask_image, height, width, return_image=True
+ )
+ mask_condition = mask.clone()
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.vae.config.latent_channels
+ num_channels_unet = self.unet.config.in_channels
+ return_image_latents = num_channels_unet == 4
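+ # image latents are only needed for 4-channel (non-inpainting) UNets, which blend the noised
+ # original image back into the unmasked region at every denoising step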
+
+ latents_outputs = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ image=init_image,
+ timestep=latent_timestep,
+ is_strength_max=is_strength_max,
+ return_noise=True,
+ return_image_latents=return_image_latents,
+ )
+
+ if return_image_latents:
+ latents, noise, image_latents = latents_outputs
+ else:
+ latents, noise = latents_outputs
+
+ # 7. Prepare mask latent variables
+ mask, masked_image_latents = self.prepare_mask_latents(
+ mask,
+ masked_image,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ do_classifier_free_guidance,
+ )
+
+ # 8. Check that sizes of mask, masked image and latents match
+ if num_channels_unet == 9:
+ # default case for runwayml/stable-diffusion-inpainting
+ num_channels_mask = mask.shape[1]
+ num_channels_masked_image = masked_image_latents.shape[1]
+ if (
+ num_channels_latents + num_channels_mask + num_channels_masked_image
+ != self.unet.config.in_channels
+ ):
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+ " `pipeline.unet` or your `mask_image` or `image` input."
+ )
+ elif num_channels_unet != 4:
+ raise ValueError(
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+ )
+
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 10. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = (
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ )
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(
+ latent_model_input, t
+ )
+
+ if num_channels_unet == 9:
+ latent_model_input = torch.cat(
+ [latent_model_input, mask, masked_image_latents], dim=1
+ )
+
+ # predict the noise residual
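+ # when a task_class is provided, pass it through to the UNet, which is expected to accept this extra argument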
+ if task_class is not None:
+ noise_pred = self.unet(
+ sample=latent_model_input,
+ timestep=t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ task_class=task_class,
+ )[0]
+ else:
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (
+ noise_pred_text - noise_pred_uncond
+ )
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
+ )[0]
+
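+ # for 4-channel UNets (standard text-to-image models), keep the unmasked region consistent by
+ # re-injecting the appropriately noised original-image latents after each step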
+ if num_channels_unet == 4:
+ init_latents_proper = image_latents[:1]
+ init_mask = mask[:1]
+
+ if i < len(timesteps) - 1:
+ noise_timestep = timesteps[i + 1]
+ init_latents_proper = self.scheduler.add_noise(
+ init_latents_proper, noise, torch.tensor([noise_timestep])
+ )
+
+ latents = (
+ 1 - init_mask
+ ) * init_latents_proper + init_mask * latents
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or (
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+ ):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ if not output_type == "latent":
+ condition_kwargs = {}
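+ # an AsymmetricAutoencoderKL decoder additionally conditions on the original image and mask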
+ if isinstance(self.vae, AsymmetricAutoencoderKL):
+ init_image = init_image.to(
+ device=device, dtype=masked_image_latents.dtype
+ )
+ init_image_condition = init_image.clone()
+ init_image = self._encode_vae_image(init_image, generator=generator)
+ mask_condition = mask_condition.to(
+ device=device, dtype=masked_image_latents.dtype
+ )
+ condition_kwargs = {
+ "image": init_image_condition,
+ "mask": mask_condition,
+ }
+ image = self.vae.decode(
+ latents / self.vae.config.scaling_factor,
+ return_dict=False,
+ **condition_kwargs,
+ )[0]
+ image, has_nsfw_concept = self.run_safety_checker(
+ image, device, prompt_embeds.dtype
+ )
+ else:
+ image = latents
+ has_nsfw_concept = None
+
+ if has_nsfw_concept is None:
+ do_denormalize = [True] * image.shape[0]
+ else:
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
+
+ image = self.image_processor.postprocess(
+ image, output_type=output_type, do_denormalize=do_denormalize
+ )
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, has_nsfw_concept)
+
+ return StableDiffusionPipelineOutput(
+ images=image, nsfw_content_detected=has_nsfw_concept
+ )
diff --git a/lama_cleaner/model/power_paint/pipeline_powerpaint_controlnet.py b/lama_cleaner/model/power_paint/pipeline_powerpaint_controlnet.py
new file mode 100644
index 0000000..cba0f8f
--- /dev/null
+++ b/lama_cleaner/model/power_paint/pipeline_powerpaint_controlnet.py
@@ -0,0 +1,1775 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/
+
+import inspect
+import warnings
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import numpy as np
+import PIL.Image
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ replace_example_docstring,
+)
+from diffusers.utils.torch_utils import randn_tensor, is_compiled_module
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.pipelines.controlnet import MultiControlNetModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+ >>> # !pip install transformers accelerate
+ >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
+ >>> from diffusers.utils import load_image
+ >>> import numpy as np
+ >>> import torch
+
+ >>> init_image = load_image(
+ ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
+ ... )
+ >>> init_image = init_image.resize((512, 512))
+
+ >>> generator = torch.Generator(device="cpu").manual_seed(1)
+
+ >>> mask_image = load_image(
+ ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
+ ... )
+ >>> mask_image = mask_image.resize((512, 512))
+
+
+ >>> def make_inpaint_condition(image, image_mask):
+ ... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
+ ... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
+
+ ... assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
+ ... image[image_mask > 0.5] = -1.0 # set as masked pixel
+ ... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
+ ... image = torch.from_numpy(image)
+ ... return image
+
+
+ >>> control_image = make_inpaint_condition(init_image, mask_image)
+
+ >>> controlnet = ControlNetModel.from_pretrained(
+ ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
+ ... )
+ >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
+ ... )
+
+ >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
+ >>> pipe.enable_model_cpu_offload()
+
+ >>> # generate image
+ >>> image = pipe(
+ ... "a handsome man with ray-ban sunglasses",
+ ... num_inference_steps=20,
+ ... generator=generator,
+ ... eta=1.0,
+ ... image=init_image,
+ ... mask_image=mask_image,
+ ... control_image=control_image,
+ ... ).images[0]
+ ```
+"""
+
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
+def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
+ """
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+ ``image`` and ``1`` for the ``mask``.
+
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+
+ Args:
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+
+
+ Raises:
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range.
+ ValueError: ``torch.Tensor`` masks should be in the ``[0, 1]`` range.
+ ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not (or the other way around).
+
+ Returns:
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+ dimensions: ``batch x channels x height x width``.
+ """
+
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ if mask is None:
+ raise ValueError("`mask_image` input cannot be undefined.")
+
+ if isinstance(image, torch.Tensor):
+ if not isinstance(mask, torch.Tensor):
+ raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
+
+ # Batch single image
+ if image.ndim == 3:
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
+ image = image.unsqueeze(0)
+
+ # Batch and add channel dim for single mask
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0).unsqueeze(0)
+
+ # Batch single mask or add channel dim
+ if mask.ndim == 3:
+ # Single batched mask, no channel dim or single mask not batched but channel dim
+ if mask.shape[0] == 1:
+ mask = mask.unsqueeze(0)
+
+ # Batched masks no channel dim
+ else:
+ mask = mask.unsqueeze(1)
+
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+ assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
+
+ # Check image is in [-1, 1]
+ if image.min() < -1 or image.max() > 1:
+ raise ValueError("Image should be in [-1, 1] range")
+
+ # Check mask is in [0, 1]
+ if mask.min() < 0 or mask.max() > 1:
+ raise ValueError("Mask should be in [0, 1] range")
+
+ # Binarize mask
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+
+ # Image as float32
+ image = image.to(dtype=torch.float32)
+ elif isinstance(mask, torch.Tensor):
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
+ else:
+ # preprocess image
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
+ image = [image]
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+ # resize all images w.r.t. the passed height and width
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+ image = np.concatenate([i[None, :] for i in image], axis=0)
+
+ image = image.transpose(0, 3, 1, 2)
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+ # preprocess mask
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+ mask = [mask]
+
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
+ mask = mask.astype(np.float32) / 255.0
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+ mask[mask < 0.5] = 0
+ mask[mask >= 0.5] = 1
+ mask = torch.from_numpy(mask)
+
+ masked_image = image * (mask < 0.5)
+
+ # n.b. ensure backwards compatibility as old function does not return image
+ if return_image:
+ return mask, masked_image, image
+
+ return mask, masked_image
+
+
+class StableDiffusionControlNetInpaintPipeline(
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
+):
+ r"""
+ Pipeline for text-to-image generation using Stable Diffusion with ControlNet guidance.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ In addition the pipeline inherits the following loading methods:
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
+
+