From 0a262fa81105934d8b235322e312b2237ce4029b Mon Sep 17 00:00:00 2001 From: Qing Date: Fri, 12 Apr 2024 11:07:41 +0800 Subject: [PATCH] make brushnet work --- iopaint/const.py | 6 +- iopaint/download.py | 11 +- iopaint/model/brushnet/brushnet.py | 931 ++++++++++++ .../model/brushnet/brushnet_unet_forward.py | 322 +++++ iopaint/model/brushnet/brushnet_wrapper.py | 157 ++ iopaint/model/brushnet/pipeline_brushnet.py | 1279 +++++++++++++++++ iopaint/model/brushnet/unet_2d_blocks.py | 388 +++++ iopaint/model_manager.py | 102 +- iopaint/schema.py | 79 +- iopaint/tests/test_brushnet.py | 89 ++ .../components/SidePanel/DiffusionOptions.tsx | 79 + web_app/src/lib/api.ts | 3 + web_app/src/lib/states.ts | 16 +- web_app/src/lib/types.ts | 2 + 14 files changed, 3408 insertions(+), 56 deletions(-) create mode 100644 iopaint/model/brushnet/brushnet.py create mode 100644 iopaint/model/brushnet/brushnet_unet_forward.py create mode 100644 iopaint/model/brushnet/brushnet_wrapper.py create mode 100644 iopaint/model/brushnet/pipeline_brushnet.py create mode 100644 iopaint/model/brushnet/unet_2d_blocks.py create mode 100644 iopaint/tests/test_brushnet.py diff --git a/iopaint/const.py b/iopaint/const.py index 148cb77..9a0386b 100644 --- a/iopaint/const.py +++ b/iopaint/const.py @@ -6,7 +6,6 @@ KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint" POWERPAINT_NAME = "Sanster/PowerPaint-V1-stable-diffusion-inpainting" ANYTEXT_NAME = "Sanster/AnyText" - DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline" DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline" DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline" @@ -62,6 +61,11 @@ SD_CONTROLNET_CHOICES: List[str] = [ "lllyasviel/control_v11f1p_sd15_depth", ] +SD_BRUSHNET_CHOICES: List[str] = [ + "Sanster/brushnet_random_mask", + "Sanster/brushnet_segmentation_mask" +] + SD2_CONTROLNET_CHOICES = [ "thibaud/controlnet-sd21-canny-diffusers", "thibaud/controlnet-sd21-depth-diffusers", diff --git a/iopaint/download.py b/iopaint/download.py index 51fd84f..a587d90 100644 --- a/iopaint/download.py +++ b/iopaint/download.py @@ -1,3 +1,4 @@ +import glob import json import os from functools import lru_cache @@ -92,7 +93,7 @@ def get_sdxl_model_type(model_abs_path: str) -> ModelType: else: model_type = ModelType.DIFFUSERS_SDXL except ValueError as e: - if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e): + if "but got torch.Size([320, 4, 3, 3])" in str(e): model_type = ModelType.DIFFUSERS_SDXL else: raise e @@ -192,7 +193,9 @@ def scan_diffusers_models() -> List[ModelInfo]: cache_dir = Path(HF_HUB_CACHE) # logger.info(f"Scanning diffusers models in {cache_dir}") diffusers_model_names = [] - for it in cache_dir.glob("**/*/model_index.json"): + model_index_files = glob.glob(os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True) + for it in model_index_files: + it = Path(it) with open(it, "r", encoding="utf-8") as f: try: data = json.load(f) @@ -238,7 +241,9 @@ def _scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]: cache_dir = Path(cache_dir) available_models = [] diffusers_model_names = [] - for it in cache_dir.glob("**/*/model_index.json"): + model_index_files = glob.glob(os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True) + for it in model_index_files: + it = Path(it) with open(it, "r", encoding="utf-8") as f: try: data = json.load(f) diff --git a/iopaint/model/brushnet/brushnet.py b/iopaint/model/brushnet/brushnet.py new file mode 100644 index 0000000..b3a045b --- /dev/null +++ b/iopaint/model/brushnet/brushnet.py @@ -0,0 +1,931 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, \ + TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unets.unet_2d_blocks import ( + CrossAttnDownBlock2D, + DownBlock2D, get_down_block, get_up_block, +) + +from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel +from .unet_2d_blocks import MidBlock2D + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class BrushNetOutput(BaseOutput): + """ + The output of [`BrushNetModel`]. + + Args: + up_block_res_samples (`tuple[torch.Tensor]`): + A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should + be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be + used to condition the original UNet's upsampling activations. + down_block_res_samples (`tuple[torch.Tensor]`): + A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should + be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be + used to condition the original UNet's downsampling activations. + mid_down_block_re_sample (`torch.Tensor`): + The activation of the midde block (the lowest sample resolution). Each tensor should be of shape + `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. + Output can be used to condition the original UNet's middle block activation. + """ + + up_block_res_samples: Tuple[torch.Tensor] + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class BrushNetModel(ModelMixin, ConfigMixin): + """ + A BrushNet model. + + Args: + in_channels (`int`, defaults to 4): + The number of channels in the input sample. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, defaults to 0): + The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): + block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, defaults to 2): + The number of layers per block. + downsample_padding (`int`, defaults to 1): + The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, defaults to 1): + The scale factor to use for the mid block. + act_fn (`str`, defaults to "silu"): + The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the normalization. If None, normalization and activation layers is skipped + in post-processing. + norm_eps (`float`, defaults to 1e-5): + The epsilon to use for the normalization. + cross_attention_dim (`int`, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): + The dimension of the attention heads. + use_linear_projection (`bool`, defaults to `False`): + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + num_class_embeds (`int`, *optional*, defaults to 0): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + upcast_attention (`bool`, defaults to `False`): + resnet_time_scale_shift (`str`, defaults to `"default"`): + Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. + projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): + The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when + `class_embed_type="projection"`. + brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. + global_pool_conditions (`bool`, defaults to `False`): + TODO(Patrick) - unused parameter. + addition_embed_type_num_heads (`int`, defaults to 64): + The number of heads to use for the `TextTimeEmbedding` layer. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 5, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str, ...] = ( + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2D", + up_block_types: Tuple[str, ...] = ( + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + "UpBlock2D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int, ...]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + brushnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + addition_embed_type_num_heads: int = 64, + ): + super().__init__() + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in_condition = nn.Conv2d( + in_channels + conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, + padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + self.down_blocks = nn.ModuleList([]) + self.brushnet_down_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + brushnet_block = zero_module(brushnet_block) + self.brushnet_down_blocks.append(brushnet_block) + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + brushnet_block = zero_module(brushnet_block) + self.brushnet_down_blocks.append(brushnet_block) + + if not is_final_block: + brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + brushnet_block = zero_module(brushnet_block) + self.brushnet_down_blocks.append(brushnet_block) + + # mid + mid_block_channel = block_out_channels[-1] + + brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) + brushnet_block = zero_module(brushnet_block) + self.brushnet_mid_block = brushnet_block + + self.mid_block = MidBlock2D( + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + dropout=0.0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block))) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + + self.up_blocks = nn.ModuleList([]) + self.brushnet_up_blocks = nn.ModuleList([]) + + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=reversed_num_attention_heads[i], + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + ) + + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + for _ in range(layers_per_block + 1): + brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + brushnet_block = zero_module(brushnet_block) + self.brushnet_up_blocks.append(brushnet_block) + + if not is_final_block: + brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + brushnet_block = zero_module(brushnet_block) + self.brushnet_up_blocks.append(brushnet_block) + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + brushnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + conditioning_channels: int = 5, + ): + r""" + Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied + where applicable. + """ + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 + ) + encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None + encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None + addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None + addition_time_embed_dim = ( + unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None + ) + + brushnet = cls( + in_channels=unet.config.in_channels, + conditioning_channels=conditioning_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + down_block_types=['DownBlock2D', 'DownBlock2D', 'DownBlock2D', 'DownBlock2D'], + mid_block_type='MidBlock2D', + up_block_types=['UpBlock2D', 'UpBlock2D', 'UpBlock2D', 'UpBlock2D'], + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + encoder_hid_dim=encoder_hid_dim, + encoder_hid_dim_type=encoder_hid_dim_type, + attention_head_dim=unet.config.attention_head_dim, + num_attention_heads=unet.config.num_attention_heads, + use_linear_projection=unet.config.use_linear_projection, + class_embed_type=unet.config.class_embed_type, + addition_embed_type=addition_embed_type, + addition_time_embed_dim=addition_time_embed_dim, + num_class_embeds=unet.config.num_class_embeds, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, + brushnet_conditioning_channel_order=brushnet_conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + ) + + if load_weights_from_unet: + conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight) + conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight + conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight + brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight) + brushnet.conv_in_condition.bias = unet.conv_in.bias + + brushnet.time_proj.load_state_dict(unet.time_proj.state_dict()) + brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + if brushnet.class_embedding: + brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) + + brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False) + brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False) + brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False) + + return brushnet + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + brushnet_cond: torch.FloatTensor, + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]: + """ + The [`BrushNetModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + brushnet_cond (`torch.FloatTensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for BrushNet outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + guess_mode (`bool`, defaults to `False`): + In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if + you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple. + + Returns: + [`~models.brushnet.BrushNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is + returned where the first element is the sample tensor. + """ + # check channel order + channel_order = self.config.brushnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + elif channel_order == "bgr": + brushnet_cond = torch.flip(brushnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}") + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + if self.config.addition_embed_type is not None: + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + + # 2. pre-process + brushnet_cond = torch.concat([sample, brushnet_cond], 1) + sample = self.conv_in_condition(brushnet_cond) + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. PaintingNet down blocks + brushnet_down_block_res_samples = () + for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks): + down_block_res_sample = brushnet_down_block(down_block_res_sample) + brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,) + + # 5. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = self.mid_block(sample, emb) + + # 6. BrushNet mid blocks + brushnet_mid_block_res_sample = self.brushnet_mid_block(sample) + + # 7. up + up_block_res_samples = () + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets):] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample, up_res_samples = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + return_res_samples=True + ) + else: + sample, up_res_samples = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + return_res_samples=True + ) + + up_block_res_samples += up_res_samples + + # 8. BrushNet up blocks + brushnet_up_block_res_samples = () + for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks): + up_block_res_sample = brushnet_up_block(up_block_res_sample) + brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,) + + # 6. scaling + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace(-1, 0, + len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples), + device=sample.device) # 0.1 to 1.0 + scales = scales * conditioning_scale + + brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples, + scales[:len( + brushnet_down_block_res_samples)])] + brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)] + brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples, + scales[ + len(brushnet_down_block_res_samples) + 1:])] + else: + brushnet_down_block_res_samples = [sample * conditioning_scale for sample in + brushnet_down_block_res_samples] + brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale + brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples] + + if self.config.global_pool_conditions: + brushnet_down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples + ] + brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True) + brushnet_up_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples + ] + + if not return_dict: + return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples) + + return BrushNetOutput( + down_block_res_samples=brushnet_down_block_res_samples, + mid_block_res_sample=brushnet_mid_block_res_sample, + up_block_res_samples=brushnet_up_block_res_samples + ) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + + +if __name__ == "__main__": + BrushNetModel.from_pretrained("/Users/cwq/data/models/brushnet/brushnet_random_mask", variant='fp16', + use_safetensors=True) diff --git a/iopaint/model/brushnet/brushnet_unet_forward.py b/iopaint/model/brushnet/brushnet_unet_forward.py new file mode 100644 index 0000000..04e8f0a --- /dev/null +++ b/iopaint/model/brushnet/brushnet_unet_forward.py @@ -0,0 +1,322 @@ +from typing import Union, Optional, Dict, Any, Tuple + +import torch +from diffusers.models.unet_2d_condition import UNet2DConditionOutput +from diffusers.utils import USE_PEFT_BACKEND, unscale_lora_layers, deprecate, scale_lora_layers + + +def brushnet_unet_forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + down_block_add_samples: Optional[Tuple[torch.Tensor]] = None, + mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None, + up_block_add_samples: Optional[Tuple[torch.Tensor]] = None, +) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added to UNet long skip connections from down blocks to up blocks for + example from ControlNet side model(s) + mid_block_additional_residual (`torch.Tensor`, *optional*): + additional residual to be added to UNet mid block output, for example from ControlNet side model + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + + Returns: + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2 ** self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + t_emb = self.get_time_embed(sample=sample, timestep=timestep) + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) + if class_emb is not None: + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + aug_emb = self.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + if self.config.addition_embed_type == "image_hint": + aug_emb, hint = aug_emb + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + encoder_hidden_states = self.process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + + # 2. pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + is_brushnet = down_block_add_samples is not None and mid_block_add_sample is not None and up_block_add_samples is not None + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + + if is_brushnet: + sample = sample + down_block_add_samples.pop(0) + + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + if is_brushnet and len(down_block_add_samples) > 0: + additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0) + for _ in range( + len(downsample_block.resnets) + (downsample_block.downsamplers != None))] + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + additional_residuals = {} + if is_brushnet and len(down_block_add_samples) > 0: + additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0) + for _ in range( + len(downsample_block.resnets) + (downsample_block.downsamplers != None))] + + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale, + **additional_residuals) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + if is_brushnet: + sample = sample + mid_block_add_sample + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets):] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + additional_residuals = {} + if is_brushnet and len(up_block_add_samples) > 0: + additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0) + for _ in range( + len(upsample_block.resnets) + (upsample_block.upsamplers != None))] + + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + additional_residuals = {} + if is_brushnet and len(up_block_add_samples) > 0: + additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0) + for _ in range( + len(upsample_block.resnets) + (upsample_block.upsamplers != None))] + + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=lora_scale, + **additional_residuals, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) diff --git a/iopaint/model/brushnet/brushnet_wrapper.py b/iopaint/model/brushnet/brushnet_wrapper.py new file mode 100644 index 0000000..69bd484 --- /dev/null +++ b/iopaint/model/brushnet/brushnet_wrapper.py @@ -0,0 +1,157 @@ +import PIL.Image +import cv2 +import torch +from loguru import logger +import numpy as np + +from ..base import DiffusionInpaintModel +from ..helper.cpu_text_encoder import CPUTextEncoderWrapper +from ..original_sd_configs import get_config_files +from ..utils import ( + handle_from_pretrained_exceptions, + get_torch_dtype, + enable_low_mem, + is_local_files_only, +) +from .brushnet import BrushNetModel +from .brushnet_unet_forward import brushnet_unet_forward +from .unet_2d_blocks import CrossAttnDownBlock2D_forward, DownBlock2D_forward, CrossAttnUpBlock2D_forward, \ + UpBlock2D_forward +from ...schema import InpaintRequest, ModelType + + +class BrushNetWrapper(DiffusionInpaintModel): + pad_mod = 8 + min_size = 512 + + def init_model(self, device: torch.device, **kwargs): + from .pipeline_brushnet import StableDiffusionBrushNetPipeline + self.model_info = kwargs["model_info"] + self.brushnet_method = kwargs["brushnet_method"] + + use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) + self.torch_dtype = torch_dtype + + model_kwargs = { + **kwargs.get("pipe_components", {}), + "local_files_only": is_local_files_only(**kwargs), + } + self.local_files_only = model_kwargs["local_files_only"] + + disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get( + "cpu_offload", False + ) + if disable_nsfw_checker: + logger.info("Disable Stable Diffusion Model NSFW checker") + model_kwargs.update( + dict( + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + ) + + logger.info(f"Loading BrushNet model from {self.brushnet_method}") + brushnet = BrushNetModel.from_pretrained(self.brushnet_method, torch_dtype=torch_dtype) + + if self.model_info.is_single_file_diffusers: + if self.model_info.model_type == ModelType.DIFFUSERS_SD: + model_kwargs["num_in_channels"] = 4 + else: + model_kwargs["num_in_channels"] = 9 + + self.model = StableDiffusionBrushNetPipeline.from_single_file( + self.model_id_or_path, + torch_dtype=torch_dtype, + load_safety_checker=not disable_nsfw_checker, + config_files=get_config_files(), + brushnet=brushnet, + **model_kwargs, + ) + else: + self.model = handle_from_pretrained_exceptions( + StableDiffusionBrushNetPipeline.from_pretrained, + pretrained_model_name_or_path=self.model_id_or_path, + variant="fp16", + torch_dtype=torch_dtype, + brushnet=brushnet, + **model_kwargs, + ) + + enable_low_mem(self.model, kwargs.get("low_mem", False)) + + if kwargs.get("cpu_offload", False) and use_gpu: + logger.info("Enable sequential cpu offload") + self.model.enable_sequential_cpu_offload(gpu_id=0) + else: + self.model = self.model.to(device) + if kwargs["sd_cpu_textencoder"]: + logger.info("Run Stable Diffusion TextEncoder on CPU") + self.model.text_encoder = CPUTextEncoderWrapper( + self.model.text_encoder, torch_dtype + ) + + self.callback = kwargs.pop("callback", None) + + # Monkey patch the forward method of the UNet to use the brushnet_unet_forward method + self.model.unet.forward = brushnet_unet_forward.__get__(self.model.unet, self.model.unet.__class__) + + for down_block in self.model.brushnet.down_blocks: + down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__) + for up_block in self.model.brushnet.up_blocks: + up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__) + + # Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward + for down_block in self.model.unet.down_blocks: + if down_block.__class__.__name__ == "CrossAttnDownBlock2D": + down_block.forward = CrossAttnDownBlock2D_forward.__get__(down_block, down_block.__class__) + else: + down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__) + + for up_block in self.model.unet.up_blocks: + if up_block.__class__.__name__ == "CrossAttnUpBlock2D": + up_block.forward = CrossAttnUpBlock2D_forward.__get__(up_block, up_block.__class__) + else: + up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__) + + def switch_brushnet_method(self, new_method: str): + self.brushnet_method = new_method + brushnet = BrushNetModel.from_pretrained( + new_method, + resume_download=True, + local_files_only=self.local_files_only, + torch_dtype=self.torch_dtype, + ).to(self.model.device) + self.model.brushnet = brushnet + + def forward(self, image, mask, config: InpaintRequest): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + return: BGR IMAGE + """ + self.set_scheduler(config) + + img_h, img_w = image.shape[:2] + normalized_mask = mask[:, :].astype("float32") / 255.0 + image = image * (1 - normalized_mask) + image = image.astype(np.uint8) + output = self.model( + image=PIL.Image.fromarray(image), + prompt=config.prompt, + negative_prompt=config.negative_prompt, + mask=PIL.Image.fromarray(mask[:, :, -1], mode="L").convert("RGB"), + num_inference_steps=config.sd_steps, + # strength=config.sd_strength, + guidance_scale=config.sd_guidance_scale, + output_type="np", + callback_on_step_end=self.callback, + height=img_h, + width=img_w, + generator=torch.manual_seed(config.sd_seed), + brushnet_conditioning_scale=config.brushnet_conditioning_scale, + ).images[0] + + output = (output * 255).round().astype("uint8") + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output diff --git a/iopaint/model/brushnet/pipeline_brushnet.py b/iopaint/model/brushnet/pipeline_brushnet.py new file mode 100644 index 0000000..2826e77 --- /dev/null +++ b/iopaint/model/brushnet/pipeline_brushnet.py @@ -0,0 +1,1279 @@ +# https://github.com/TencentARC/BrushNet +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + +from .brushnet import BrushNetModel + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + from diffusers import StableDiffusionBrushNetPipeline, BrushNetModel, UniPCMultistepScheduler + from diffusers.utils import load_image + import torch + import cv2 + import numpy as np + from PIL import Image + + base_model_path = "runwayml/stable-diffusion-v1-5" + brushnet_path = "ckpt_path" + + brushnet = BrushNetModel.from_pretrained(brushnet_path, torch_dtype=torch.float16) + pipe = StableDiffusionBrushNetPipeline.from_pretrained( + base_model_path, brushnet=brushnet, torch_dtype=torch.float16, low_cpu_mem_usage=False + ) + + # speed up diffusion process with faster scheduler and memory optimization + pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + # remove following line if xformers is not installed or when using Torch 2.0. + # pipe.enable_xformers_memory_efficient_attention() + # memory optimization. + pipe.enable_model_cpu_offload() + + image_path="examples/brushnet/src/test_image.jpg" + mask_path="examples/brushnet/src/test_mask.jpg" + caption="A cake on the table." + + init_image = cv2.imread(image_path) + mask_image = 1.*(cv2.imread(mask_path).sum(-1)>255)[:,:,np.newaxis] + init_image = init_image * (1-mask_image) + + init_image = Image.fromarray(init_image.astype(np.uint8)).convert("RGB") + mask_image = Image.fromarray(mask_image.astype(np.uint8).repeat(3,-1)*255).convert("RGB") + + generator = torch.Generator("cuda").manual_seed(1234) + + image = pipe( + caption, + init_image, + mask_image, + num_inference_steps=50, + generator=generator, + paintingnet_conditioning_scale=1.0 + ).images[0] + image.save("output.png") + ``` +""" + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class StableDiffusionBrushNetPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + LoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for text-to-image generation using Stable Diffusion with BrushNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + brushnet ([`BrushNetModel`]`): + Provides additional conditioning to the `unet` during the denoising process. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + brushnet: BrushNetModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + brushnet=brushnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1: -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. + prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack( + [single_negative_image_embeds] * num_images_per_prompt, dim=0 + ) + + if do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + else: + repeat_dims = [1] + image_embeds = [] + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + else: + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + image_embeds.append(single_image_embeds) + + return image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + mask, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + brushnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.brushnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.brushnet, BrushNetModel) + or is_compiled + and isinstance(self.brushnet._orig_mod, BrushNetModel) + ): + self.check_image(image, mask, prompt, prompt_embeds) + else: + assert False + + # Check `brushnet_conditioning_scale` + if ( + isinstance(self.brushnet, BrushNetModel) + or is_compiled + and isinstance(self.brushnet._orig_mod, BrushNetModel) + ): + if not isinstance(brushnet_conditioning_scale, float): + raise TypeError("For single brushnet: `brushnet_conditioning_scale` must be type `float`.") + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def check_image(self, image, mask, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + mask_is_pil = isinstance(mask, PIL.Image.Image) + mask_is_tensor = isinstance(mask, torch.Tensor) + mask_is_np = isinstance(mask, np.ndarray) + mask_is_pil_list = isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image) + mask_is_tensor_list = isinstance(mask, list) and isinstance(mask[0], torch.Tensor) + mask_is_np_list = isinstance(mask, list) and isinstance(mask[0], np.ndarray) + + if ( + not mask_is_pil + and not mask_is_tensor + and not mask_is_np + and not mask_is_pil_list + and not mask_is_tensor_list + and not mask_is_np_list + ): + raise TypeError( + f"mask must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(mask)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + noise = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = noise * self.scheduler.init_noise_sigma + return latents, noise + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: PipelineImageInput = None, + mask: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + brushnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The BrushNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.FloatTensor`, it is passed to BrushNet as is. `PIL.Image.Image` can also be + accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized accordingly. If multiple BrushNets are specified in + `init`, images must be passed as a list such that each element of the list can be correctly batched for + input to a single BrushNet. When `prompt` is a list, and if a list of images is passed for a single BrushNet, + each will be paired with each prompt in the `prompt` list. This also applies to multiple BrushNets, + where a list of image lists can be passed to batch for each prompt and each BrushNet. + mask (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The BrushNet input condition to provide guidance to the `unet` for generation. If the type is + specified as `torch.FloatTensor`, it is passed to BrushNet as is. `PIL.Image.Image` can also be + accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height + and/or width are passed, `image` is resized accordingly. If multiple BrushNets are specified in + `init`, images must be passed as a list such that each element of the list can be correctly batched for + input to a single BrushNet. When `prompt` is a list, and if a list of images is passed for a single BrushNet, + each will be paired with each prompt in the `prompt` list. This also applies to multiple BrushNets, + where a list of image lists can be passed to batch for each prompt and each BrushNet. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. + Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding + if `do_classifier_free_guidance` is set to `True`. + If not provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + brushnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the BrushNet are multiplied by `brushnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple BrushNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The BrushNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the BrushNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the BrushNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + brushnet = self.brushnet._orig_mod if is_compiled_module(self.brushnet) else self.brushnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + control_guidance_start, control_guidance_end = ( + [control_guidance_start], + [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + mask, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + brushnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + global_pool_conditions = ( + brushnet.config.global_pool_conditions + if isinstance(brushnet, BrushNetModel) + else brushnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. Prepare image + if isinstance(brushnet, BrushNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=brushnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + original_mask = self.prepare_image( + image=mask, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=brushnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + original_mask = (original_mask.sum(1)[:, None, :, :] < 0).to(image.dtype) + height, width = image.shape[-2:] + else: + assert False + + # 5. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents, noise = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.1 prepare condition latents + conditioning_latents = self.vae.encode(image).latent_dist.sample() * self.vae.config.scaling_factor + mask = torch.nn.functional.interpolate( + original_mask, + size=( + conditioning_latents.shape[-2], + conditioning_latents.shape[-1] + ) + ) + conditioning_latents = torch.concat([conditioning_latents, mask], 1) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.2 Create tensor stating which brushnets to keep + brushnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + brushnet_keep.append(keeps[0] if isinstance(brushnet, BrushNetModel) else keeps) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_brushnet_compiled = is_compiled_module(self.brushnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_brushnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # brushnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer BrushNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + brushnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + brushnet_prompt_embeds = prompt_embeds + + if isinstance(brushnet_keep[i], list): + cond_scale = [c * s for c, s in zip(brushnet_conditioning_scale, brushnet_keep[i])] + else: + brushnet_cond_scale = brushnet_conditioning_scale + if isinstance(brushnet_cond_scale, list): + brushnet_cond_scale = brushnet_cond_scale[0] + cond_scale = brushnet_cond_scale * brushnet_keep[i] + + down_block_res_samples, mid_block_res_sample, up_block_res_samples = self.brushnet( + control_model_input, + t, + encoder_hidden_states=brushnet_prompt_embeds, + brushnet_cond=conditioning_latents, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Infered BrushNet only for the conditional batch. + # To apply the output of BrushNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + up_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in up_block_res_samples] + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_add_samples=down_block_res_samples, + mid_block_add_sample=mid_block_res_sample, + up_block_add_samples=up_block_res_samples, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # If we do sequential model offloading, let's offload unet and brushnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.brushnet.to("cpu") + torch.cuda.empty_cache() + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/iopaint/model/brushnet/unet_2d_blocks.py b/iopaint/model/brushnet/unet_2d_blocks.py new file mode 100644 index 0000000..dcaae8e --- /dev/null +++ b/iopaint/model/brushnet/unet_2d_blocks.py @@ -0,0 +1,388 @@ +from typing import Dict, Any, Optional, Tuple + +import torch +from diffusers.models.resnet import ResnetBlock2D +from diffusers.utils import is_torch_version +from diffusers.utils.torch_utils import apply_freeu +from torch import nn + + +class MidBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + use_linear_projection: bool = False, + ): + super().__init__() + + self.has_cross_attention = False + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + + for i in range(num_layers): + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + lora_scale = 1.0 + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + for resnet in self.resnets[1:]: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + + return hidden_states + + +def DownBlock2D_forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, + down_block_add_samples: Optional[torch.FloatTensor] = None, +) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + + if down_block_add_samples is not None: + hidden_states = hidden_states + down_block_add_samples.pop(0) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=scale) + + if down_block_add_samples is not None: + hidden_states = hidden_states + down_block_add_samples.pop(0) # todo: add before or after + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +def CrossAttnDownBlock2D_forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + additional_residuals: Optional[torch.FloatTensor] = None, + down_block_add_samples: Optional[torch.FloatTensor] = None, +) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + if down_block_add_samples is not None: + hidden_states = hidden_states + down_block_add_samples.pop(0) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=lora_scale) + + if down_block_add_samples is not None: + hidden_states = hidden_states + down_block_add_samples.pop(0) # todo: add before or after + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +def CrossAttnUpBlock2D_forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + return_res_samples: Optional[bool] = False, + up_block_add_samples: Optional[torch.FloatTensor] = None, +) -> torch.FloatTensor: + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + if return_res_samples: + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if return_res_samples: + output_states = output_states + (hidden_states,) + if up_block_add_samples is not None: + hidden_states = hidden_states + up_block_add_samples.pop(0) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) + if return_res_samples: + output_states = output_states + (hidden_states,) + if up_block_add_samples is not None: + hidden_states = hidden_states + up_block_add_samples.pop(0) + + if return_res_samples: + return hidden_states, output_states + else: + return hidden_states + + +def UpBlock2D_forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + scale: float = 1.0, + return_res_samples: Optional[bool] = False, + up_block_add_samples: Optional[torch.FloatTensor] = None, +) -> torch.FloatTensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + if return_res_samples: + output_states = () + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + + if return_res_samples: + output_states = output_states + (hidden_states,) + if up_block_add_samples is not None: + hidden_states = hidden_states + up_block_add_samples.pop(0) # todo: add before or after + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=scale) + + if return_res_samples: + output_states = output_states + (hidden_states,) + if up_block_add_samples is not None: + hidden_states = hidden_states + up_block_add_samples.pop(0) # todo: add before or after + + if return_res_samples: + return hidden_states, output_states + else: + return hidden_states diff --git a/iopaint/model_manager.py b/iopaint/model_manager.py index 15cf32b..fdd2c46 100644 --- a/iopaint/model_manager.py +++ b/iopaint/model_manager.py @@ -7,6 +7,7 @@ import numpy as np from iopaint.download import scan_models from iopaint.helper import switch_mps_device from iopaint.model import models, ControlNet, SD, SDXL +from iopaint.model.brushnet.brushnet_wrapper import BrushNetWrapper from iopaint.model.utils import torch_gc, is_local_files_only from iopaint.schema import InpaintRequest, ModelInfo, ModelType @@ -22,12 +23,16 @@ class ModelManager: self.enable_controlnet = kwargs.get("enable_controlnet", False) controlnet_method = kwargs.get("controlnet_method", None) if ( - controlnet_method is None - and name in self.available_models - and self.available_models[name].support_controlnet + controlnet_method is None + and name in self.available_models + and self.available_models[name].support_controlnet ): controlnet_method = self.available_models[name].controlnets[0] self.controlnet_method = controlnet_method + + self.enable_brushnet = kwargs.get("enable_brushnet", False) + self.brushnet_method = kwargs.get("brushnet_method", None) + self.model = self.init_model(name, device, **kwargs) @property @@ -47,24 +52,30 @@ class ModelManager: "model_info": model_info, "enable_controlnet": self.enable_controlnet, "controlnet_method": self.controlnet_method, + "enable_brushnet": self.enable_brushnet, + "brushnet_method": self.brushnet_method, } if model_info.support_controlnet and self.enable_controlnet: return ControlNet(device, **kwargs) - elif model_info.name in models: - return models[name](device, **kwargs) - else: - if model_info.model_type in [ - ModelType.DIFFUSERS_SD_INPAINT, - ModelType.DIFFUSERS_SD, - ]: - return SD(device, **kwargs) - if model_info.model_type in [ - ModelType.DIFFUSERS_SDXL_INPAINT, - ModelType.DIFFUSERS_SDXL, - ]: - return SDXL(device, **kwargs) + if model_info.support_brushnet and self.enable_brushnet: + return BrushNetWrapper(device, **kwargs) + + if model_info.name in models: + return models[name](device, **kwargs) + + if model_info.model_type in [ + ModelType.DIFFUSERS_SD_INPAINT, + ModelType.DIFFUSERS_SD, + ]: + return SD(device, **kwargs) + + if model_info.model_type in [ + ModelType.DIFFUSERS_SDXL_INPAINT, + ModelType.DIFFUSERS_SDXL, + ]: + return SDXL(device, **kwargs) raise NotImplementedError(f"Unsupported model: {name}") @@ -80,7 +91,10 @@ class ModelManager: Returns: BGR image """ - self.switch_controlnet_method(config) + if not config.enable_brushnet: + self.switch_controlnet_method(config) + if not config.enable_controlnet: + self.switch_brushnet_method(config) self.enable_disable_freeu(config) self.enable_disable_lcm_lora(config) return self.model(image, mask, config).astype(np.uint8) @@ -99,9 +113,9 @@ class ModelManager: self.name = new_name if ( - self.available_models[new_name].support_controlnet - and self.controlnet_method - not in self.available_models[new_name].controlnets + self.available_models[new_name].support_controlnet + and self.controlnet_method + not in self.available_models[new_name].controlnets ): self.controlnet_method = self.available_models[new_name].controlnets[0] try: @@ -121,14 +135,54 @@ class ModelManager: ) raise e + def switch_brushnet_method(self, config): + if not self.available_models[self.name].support_brushnet: + return + + if ( + self.enable_brushnet + and config.brushnet_method + and self.brushnet_method != config.brushnet_method + ): + old_brushnet_method = self.brushnet_method + self.brushnet_method = config.brushnet_method + self.model.switch_brushnet_method(config.brushnet_method) + logger.info( + f"Switch Brushnet method from {old_brushnet_method} to {config.brushnet_method}" + ) + + elif self.enable_brushnet != config.enable_brushnet: + self.enable_brushnet = config.enable_brushnet + self.brushnet_method = config.brushnet_method + + pipe_components = { + "vae": self.model.model.vae, + "text_encoder": self.model.model.text_encoder, + "unet": self.model.model.unet, + } + if hasattr(self.model.model, "text_encoder_2"): + pipe_components["text_encoder_2"] = self.model.model.text_encoder_2 + + self.model = self.init_model( + self.name, + switch_mps_device(self.name, self.device), + pipe_components=pipe_components, + **self.kwargs, + ) + + if not config.enable_brushnet: + logger.info("BrushNet Disabled") + else: + logger.info("BrushNet Enabled") + def switch_controlnet_method(self, config): if not self.available_models[self.name].support_controlnet: return if ( - self.enable_controlnet - and config.controlnet_method - and self.controlnet_method != config.controlnet_method + self.enable_controlnet + and config.controlnet_method + and self.controlnet_method != config.controlnet_method ): old_controlnet_method = self.controlnet_method self.controlnet_method = config.controlnet_method @@ -155,7 +209,7 @@ class ModelManager: **self.kwargs, ) if not config.enable_controlnet: - logger.info(f"Disable controlnet") + logger.info("Disable controlnet") else: logger.info(f"Enable controlnet: {config.controlnet_method}") diff --git a/iopaint/schema.py b/iopaint/schema.py index c8ba9ca..24c115b 100644 --- a/iopaint/schema.py +++ b/iopaint/schema.py @@ -3,6 +3,8 @@ from enum import Enum from pathlib import Path from typing import Optional, Literal, List +from loguru import logger + from iopaint.const import ( INSTRUCT_PIX2PIX_NAME, KANDINSKY22_NAME, @@ -11,9 +13,9 @@ from iopaint.const import ( SDXL_CONTROLNET_CHOICES, SD2_CONTROLNET_CHOICES, SD_CONTROLNET_CHOICES, + SD_BRUSHNET_CHOICES, ) -from loguru import logger -from pydantic import BaseModel, Field, field_validator, computed_field +from pydantic import BaseModel, Field, computed_field, model_validator class ModelType(str, Enum): @@ -63,6 +65,13 @@ class ModelInfo(BaseModel): return SD_CONTROLNET_CHOICES return [] + @computed_field + @property + def brushnets(self) -> List[str]: + if self.model_type in [ModelType.DIFFUSERS_SD]: + return SD_BRUSHNET_CHOICES + return [] + @computed_field @property def support_strength(self) -> bool: @@ -103,6 +112,13 @@ class ModelInfo(BaseModel): ModelType.DIFFUSERS_SDXL_INPAINT, ] + @computed_field + @property + def support_brushnet(self) -> bool: + return self.model_type in [ + ModelType.DIFFUSERS_SD, + ] + @computed_field @property def support_freeu(self) -> bool: @@ -369,6 +385,13 @@ class InpaintRequest(BaseModel): "lllyasviel/control_v11p_sd15_canny", description="Controlnet method" ) + # BrushNet + enable_brushnet: bool = Field(False, description="Enable brushnet") + brushnet_method: str = Field( + SD_BRUSHNET_CHOICES[0], description="Brushnet method" + ) + brushnet_conditioning_scale: float = Field(1.0, description="brushnet conditioning scale", ge=0.0, le=1.0) + # PowerPaint powerpaint_task: PowerPaintTask = Field( PowerPaintTask.text_guided, description="PowerPaint task" @@ -380,31 +403,37 @@ class InpaintRequest(BaseModel): le=1.0, ) - @field_validator("sd_seed") - @classmethod - def sd_seed_validator(cls, v: int) -> int: - if v == -1: - return random.randint(1, 99999999) - return v + @model_validator(mode='after') + def validate_field(cls, values: 'InpaintRequest'): + if values.sd_seed == -1: + values.sd_seed = random.randint(1, 99999999) + logger.info(f"Generate random seed: {values.sd_seed}") - @field_validator("controlnet_conditioning_scale") - @classmethod - def validate_field(cls, v: float, values): - use_extender = values.data["use_extender"] - enable_controlnet = values.data["enable_controlnet"] - if use_extender and enable_controlnet: - logger.info(f"Extender is enabled, set controlnet_conditioning_scale=0") - return 0 - return v + if values.use_extender and values.enable_controlnet: + logger.info("Extender is enabled, set controlnet_conditioning_scale=0") + values.controlnet_conditioning_scale = 0 - @field_validator("sd_strength") - @classmethod - def validate_sd_strength(cls, v: float, values): - use_extender = values.data["use_extender"] - if use_extender: - logger.info(f"Extender is enabled, set sd_strength=1") - return 1.0 - return v + if values.use_extender: + logger.info("Extender is enabled, set sd_strength=1") + values.sd_strength = 1.0 + + if values.enable_brushnet: + logger.info("BrushNet is enabled, set enable_controlnet=False") + if values.enable_controlnet: + values.enable_controlnet = False + if values.sd_lcm_lora: + logger.info("BrushNet is enabled, set sd_lcm_lora=False") + values.sd_lcm_lora = False + if values.sd_freeu: + logger.info("BrushNet is enabled, set sd_freeu=False") + values.sd_freeu = False + + if values.enable_controlnet: + logger.info("ControlNet is enabled, set enable_brushnet=False") + if values.enable_brushnet: + values.enable_brushnet = False + + return values class RunPluginRequest(BaseModel): diff --git a/iopaint/tests/test_brushnet.py b/iopaint/tests/test_brushnet.py new file mode 100644 index 0000000..e46d194 --- /dev/null +++ b/iopaint/tests/test_brushnet.py @@ -0,0 +1,89 @@ +import os + +from iopaint.const import SD_BRUSHNET_CHOICES +from iopaint.tests.utils import check_device, get_config, assert_equal + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" +from pathlib import Path + +import pytest +import torch + +from iopaint.model_manager import ModelManager +from iopaint.schema import HDStrategy, SDSampler, FREEUConfig + +current_dir = Path(__file__).parent.absolute().resolve() +save_dir = current_dir / "result" +save_dir.mkdir(exist_ok=True, parents=True) + + +@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) +@pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m_karras]) +def test_runway_brushnet(device, sampler): + sd_steps = check_device(device) + model = ModelManager( + name="runwayml/stable-diffusion-v1-5", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + ) + cfg = get_config( + strategy=HDStrategy.ORIGINAL, + prompt="face of a fox, sitting on a bench", + sd_steps=sd_steps, + sd_guidance_scale=7.5, + sd_freeu=True, + sd_freeu_config=FREEUConfig(), + enable_brushnet=True, + brushnet_method=SD_BRUSHNET_CHOICES[0] + ) + cfg.sd_sampler = sampler + + assert_equal( + model, + cfg, + f"brushnet_runway_1_5_freeu_device_{device}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + +@pytest.mark.parametrize("device", ["cuda", "mps"]) +@pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m_karras]) +@pytest.mark.parametrize( + "name", + [ + "v1-5-pruned-emaonly.safetensors", + ], +) +def test_brushnet_local_file_path(device, sampler, name): + sd_steps = check_device(device) + model = ModelManager( + name=name, + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + cpu_offload=False, + ) + cfg = get_config( + strategy=HDStrategy.ORIGINAL, + prompt="face of a fox, sitting on a bench", + sd_steps=sd_steps, + sd_seed=1234, + enable_brushnet=True, + brushnet_method=SD_BRUSHNET_CHOICES[1] + ) + cfg.sd_sampler = sampler + name = f"device_{device}_{sampler}_{name}" + + is_sdxl = "sd_xl" in name + + assert_equal( + model, + cfg, + f"brushnet_sd_local_model_{name}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=1.5 if is_sdxl else 1, + fy=1.5 if is_sdxl else 1, + ) diff --git a/web_app/src/components/SidePanel/DiffusionOptions.tsx b/web_app/src/components/SidePanel/DiffusionOptions.tsx index c5702d2..4999ec0 100644 --- a/web_app/src/components/SidePanel/DiffusionOptions.tsx +++ b/web_app/src/components/SidePanel/DiffusionOptions.tsx @@ -109,6 +109,84 @@ const DiffusionOptions = () => { ) } + const renderBrushNetSetting = () => { + if (!settings.model.support_brushnet) { + return null + } + + return ( +
+
+
+ + { + updateSettings({ enableBrushNet: value }) + }} + /> +
+
+ + + updateSettings({ brushnetConditioningScale: vals[0] / 100 }) + } + /> + { + updateSettings({ brushnetConditioningScale: val }) + }} + /> + +
+ +
+ +
+
+ +
+ ) + } + const renderConterNetSetting = () => { if (!settings.model.support_controlnet) { return null @@ -881,6 +959,7 @@ const DiffusionOptions = () => { {renderSeed()} {renderNegativePrompt()} + {renderBrushNetSetting()} {renderConterNetSetting()} {renderLCMLora()} {renderMaskBlur()} diff --git a/web_app/src/lib/api.ts b/web_app/src/lib/api.ts index ce2900f..21e72dd 100644 --- a/web_app/src/lib/api.ts +++ b/web_app/src/lib/api.ts @@ -78,6 +78,9 @@ export default async function inpaint( controlnet_method: settings.controlnetMethod ? settings.controlnetMethod : "", + enable_brushnet: settings.enableBrushNet, + brushnet_method: settings.brushnetMethod ? settings.brushnetMethod : "", + brushnet_conditioning_scale: settings.brushnetConditioningScale, powerpaint_task: settings.showExtender ? PowerPaintTask.outpainting : settings.powerpaintTask, diff --git a/web_app/src/lib/states.ts b/web_app/src/lib/states.ts index faba1ba..aa3adf2 100644 --- a/web_app/src/lib/states.ts +++ b/web_app/src/lib/states.ts @@ -99,6 +99,11 @@ export type Settings = { controlnetConditioningScale: number controlnetMethod: string + // BrushNet + enableBrushNet: boolean + brushnetMethod: string + brushnetConditioningScale: number + enableLCMLora: boolean enableFreeu: boolean freeuConfig: FreeuConfig @@ -306,15 +311,16 @@ const defaultValues: AppState = { path: "lama", model_type: "inpaint", support_controlnet: false, + support_brushnet: false, support_strength: false, support_outpainting: false, controlnets: [], + brushnets: [], support_freeu: false, support_lcm_lora: false, is_single_file_diffusers: false, need_prompt: false, }, - enableControlnet: false, showCropper: false, showExtender: false, extenderDirection: ExtenderDirection.xy, @@ -339,8 +345,12 @@ const defaultValues: AppState = { sdMatchHistograms: false, sdScale: 1.0, p2pImageGuidanceScale: 1.5, - controlnetConditioningScale: 0.4, + enableControlnet: false, controlnetMethod: "lllyasviel/control_v11p_sd15_canny", + controlnetConditioningScale: 0.4, + enableBrushNet: false, + brushnetMethod: "random_mask", + brushnetConditioningScale: 1.0, enableLCMLora: false, enableFreeu: false, freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 }, @@ -1076,7 +1086,7 @@ export const useStore = createWithEqualityFn()( })), { name: "ZUSTAND_STATE", // name of the item in the storage (must be unique) - version: 1, + version: 2, partialize: (state) => Object.fromEntries( Object.entries(state).filter(([key]) => diff --git a/web_app/src/lib/types.ts b/web_app/src/lib/types.ts index 38cacc0..70e436b 100644 --- a/web_app/src/lib/types.ts +++ b/web_app/src/lib/types.ts @@ -48,7 +48,9 @@ export interface ModelInfo { support_strength: boolean support_outpainting: boolean support_controlnet: boolean + support_brushnet: boolean controlnets: string[] + brushnets: string[] support_freeu: boolean support_lcm_lora: boolean need_prompt: boolean