make brushnet work
This commit is contained in:
parent
35f12d5b9b
commit
0a262fa811
@ -6,7 +6,6 @@ KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
|
|||||||
POWERPAINT_NAME = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
|
POWERPAINT_NAME = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
|
||||||
ANYTEXT_NAME = "Sanster/AnyText"
|
ANYTEXT_NAME = "Sanster/AnyText"
|
||||||
|
|
||||||
|
|
||||||
DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
|
DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
|
||||||
DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
|
DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
|
||||||
DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
|
DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
|
||||||
@ -62,6 +61,11 @@ SD_CONTROLNET_CHOICES: List[str] = [
|
|||||||
"lllyasviel/control_v11f1p_sd15_depth",
|
"lllyasviel/control_v11f1p_sd15_depth",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
SD_BRUSHNET_CHOICES: List[str] = [
|
||||||
|
"Sanster/brushnet_random_mask",
|
||||||
|
"Sanster/brushnet_segmentation_mask"
|
||||||
|
]
|
||||||
|
|
||||||
SD2_CONTROLNET_CHOICES = [
|
SD2_CONTROLNET_CHOICES = [
|
||||||
"thibaud/controlnet-sd21-canny-diffusers",
|
"thibaud/controlnet-sd21-canny-diffusers",
|
||||||
"thibaud/controlnet-sd21-depth-diffusers",
|
"thibaud/controlnet-sd21-depth-diffusers",
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import glob
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
@ -92,7 +93,7 @@ def get_sdxl_model_type(model_abs_path: str) -> ModelType:
|
|||||||
else:
|
else:
|
||||||
model_type = ModelType.DIFFUSERS_SDXL
|
model_type = ModelType.DIFFUSERS_SDXL
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e):
|
if "but got torch.Size([320, 4, 3, 3])" in str(e):
|
||||||
model_type = ModelType.DIFFUSERS_SDXL
|
model_type = ModelType.DIFFUSERS_SDXL
|
||||||
else:
|
else:
|
||||||
raise e
|
raise e
|
||||||
@ -192,7 +193,9 @@ def scan_diffusers_models() -> List[ModelInfo]:
|
|||||||
cache_dir = Path(HF_HUB_CACHE)
|
cache_dir = Path(HF_HUB_CACHE)
|
||||||
# logger.info(f"Scanning diffusers models in {cache_dir}")
|
# logger.info(f"Scanning diffusers models in {cache_dir}")
|
||||||
diffusers_model_names = []
|
diffusers_model_names = []
|
||||||
for it in cache_dir.glob("**/*/model_index.json"):
|
model_index_files = glob.glob(os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True)
|
||||||
|
for it in model_index_files:
|
||||||
|
it = Path(it)
|
||||||
with open(it, "r", encoding="utf-8") as f:
|
with open(it, "r", encoding="utf-8") as f:
|
||||||
try:
|
try:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
@ -238,7 +241,9 @@ def _scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
|
|||||||
cache_dir = Path(cache_dir)
|
cache_dir = Path(cache_dir)
|
||||||
available_models = []
|
available_models = []
|
||||||
diffusers_model_names = []
|
diffusers_model_names = []
|
||||||
for it in cache_dir.glob("**/*/model_index.json"):
|
model_index_files = glob.glob(os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True)
|
||||||
|
for it in model_index_files:
|
||||||
|
it = Path(it)
|
||||||
with open(it, "r", encoding="utf-8") as f:
|
with open(it, "r", encoding="utf-8") as f:
|
||||||
try:
|
try:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
931
iopaint/model/brushnet/brushnet.py
Normal file
931
iopaint/model/brushnet/brushnet.py
Normal file
@ -0,0 +1,931 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||||
|
from diffusers.utils import BaseOutput, logging
|
||||||
|
from diffusers.models.attention_processor import (
|
||||||
|
ADDED_KV_ATTENTION_PROCESSORS,
|
||||||
|
CROSS_ATTENTION_PROCESSORS,
|
||||||
|
AttentionProcessor,
|
||||||
|
AttnAddedKVProcessor,
|
||||||
|
AttnProcessor,
|
||||||
|
)
|
||||||
|
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, \
|
||||||
|
TimestepEmbedding, Timesteps
|
||||||
|
from diffusers.models.modeling_utils import ModelMixin
|
||||||
|
from diffusers.models.unets.unet_2d_blocks import (
|
||||||
|
CrossAttnDownBlock2D,
|
||||||
|
DownBlock2D, get_down_block, get_up_block,
|
||||||
|
)
|
||||||
|
|
||||||
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||||
|
from .unet_2d_blocks import MidBlock2D
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BrushNetOutput(BaseOutput):
|
||||||
|
"""
|
||||||
|
The output of [`BrushNetModel`].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
up_block_res_samples (`tuple[torch.Tensor]`):
|
||||||
|
A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
|
||||||
|
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
||||||
|
used to condition the original UNet's upsampling activations.
|
||||||
|
down_block_res_samples (`tuple[torch.Tensor]`):
|
||||||
|
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
||||||
|
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
||||||
|
used to condition the original UNet's downsampling activations.
|
||||||
|
mid_down_block_re_sample (`torch.Tensor`):
|
||||||
|
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
|
||||||
|
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
||||||
|
Output can be used to condition the original UNet's middle block activation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
up_block_res_samples: Tuple[torch.Tensor]
|
||||||
|
down_block_res_samples: Tuple[torch.Tensor]
|
||||||
|
mid_block_res_sample: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class BrushNetModel(ModelMixin, ConfigMixin):
|
||||||
|
"""
|
||||||
|
A BrushNet model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (`int`, defaults to 4):
|
||||||
|
The number of channels in the input sample.
|
||||||
|
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||||
|
Whether to flip the sin to cos in the time embedding.
|
||||||
|
freq_shift (`int`, defaults to 0):
|
||||||
|
The frequency shift to apply to the time embedding.
|
||||||
|
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||||
|
The tuple of downsample blocks to use.
|
||||||
|
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
||||||
|
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
|
||||||
|
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
||||||
|
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
||||||
|
The tuple of upsample blocks to use.
|
||||||
|
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||||
|
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||||
|
The tuple of output channels for each block.
|
||||||
|
layers_per_block (`int`, defaults to 2):
|
||||||
|
The number of layers per block.
|
||||||
|
downsample_padding (`int`, defaults to 1):
|
||||||
|
The padding to use for the downsampling convolution.
|
||||||
|
mid_block_scale_factor (`float`, defaults to 1):
|
||||||
|
The scale factor to use for the mid block.
|
||||||
|
act_fn (`str`, defaults to "silu"):
|
||||||
|
The activation function to use.
|
||||||
|
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||||
|
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
||||||
|
in post-processing.
|
||||||
|
norm_eps (`float`, defaults to 1e-5):
|
||||||
|
The epsilon to use for the normalization.
|
||||||
|
cross_attention_dim (`int`, defaults to 1280):
|
||||||
|
The dimension of the cross attention features.
|
||||||
|
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||||
|
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||||
|
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||||
|
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||||
|
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||||
|
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||||
|
dimension to `cross_attention_dim`.
|
||||||
|
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||||
|
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||||
|
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||||
|
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||||
|
The dimension of the attention heads.
|
||||||
|
use_linear_projection (`bool`, defaults to `False`):
|
||||||
|
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||||
|
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
||||||
|
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||||
|
addition_embed_type (`str`, *optional*, defaults to `None`):
|
||||||
|
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||||
|
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||||
|
num_class_embeds (`int`, *optional*, defaults to 0):
|
||||||
|
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||||
|
class conditioning with `class_embed_type` equal to `None`.
|
||||||
|
upcast_attention (`bool`, defaults to `False`):
|
||||||
|
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||||
|
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||||
|
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
||||||
|
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
||||||
|
`class_embed_type="projection"`.
|
||||||
|
brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
||||||
|
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
||||||
|
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
||||||
|
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||||
|
global_pool_conditions (`bool`, defaults to `False`):
|
||||||
|
TODO(Patrick) - unused parameter.
|
||||||
|
addition_embed_type_num_heads (`int`, defaults to 64):
|
||||||
|
The number of heads to use for the `TextTimeEmbedding` layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_supports_gradient_checkpointing = True
|
||||||
|
|
||||||
|
@register_to_config
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 4,
|
||||||
|
conditioning_channels: int = 5,
|
||||||
|
flip_sin_to_cos: bool = True,
|
||||||
|
freq_shift: int = 0,
|
||||||
|
down_block_types: Tuple[str, ...] = (
|
||||||
|
"DownBlock2D",
|
||||||
|
"DownBlock2D",
|
||||||
|
"DownBlock2D",
|
||||||
|
"DownBlock2D",
|
||||||
|
),
|
||||||
|
mid_block_type: Optional[str] = "UNetMidBlock2D",
|
||||||
|
up_block_types: Tuple[str, ...] = (
|
||||||
|
"UpBlock2D",
|
||||||
|
"UpBlock2D",
|
||||||
|
"UpBlock2D",
|
||||||
|
"UpBlock2D",
|
||||||
|
),
|
||||||
|
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||||
|
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||||
|
layers_per_block: int = 2,
|
||||||
|
downsample_padding: int = 1,
|
||||||
|
mid_block_scale_factor: float = 1,
|
||||||
|
act_fn: str = "silu",
|
||||||
|
norm_num_groups: Optional[int] = 32,
|
||||||
|
norm_eps: float = 1e-5,
|
||||||
|
cross_attention_dim: int = 1280,
|
||||||
|
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||||
|
encoder_hid_dim: Optional[int] = None,
|
||||||
|
encoder_hid_dim_type: Optional[str] = None,
|
||||||
|
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||||
|
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||||
|
use_linear_projection: bool = False,
|
||||||
|
class_embed_type: Optional[str] = None,
|
||||||
|
addition_embed_type: Optional[str] = None,
|
||||||
|
addition_time_embed_dim: Optional[int] = None,
|
||||||
|
num_class_embeds: Optional[int] = None,
|
||||||
|
upcast_attention: bool = False,
|
||||||
|
resnet_time_scale_shift: str = "default",
|
||||||
|
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||||
|
brushnet_conditioning_channel_order: str = "rgb",
|
||||||
|
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||||
|
global_pool_conditions: bool = False,
|
||||||
|
addition_embed_type_num_heads: int = 64,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||||
|
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||||
|
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||||
|
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||||
|
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||||
|
# which is why we correct for the naming here.
|
||||||
|
num_attention_heads = num_attention_heads or attention_head_dim
|
||||||
|
|
||||||
|
# Check inputs
|
||||||
|
if len(down_block_types) != len(up_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(block_out_channels) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(transformer_layers_per_block, int):
|
||||||
|
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||||
|
|
||||||
|
# input
|
||||||
|
conv_in_kernel = 3
|
||||||
|
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||||
|
self.conv_in_condition = nn.Conv2d(
|
||||||
|
in_channels + conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel,
|
||||||
|
padding=conv_in_padding
|
||||||
|
)
|
||||||
|
|
||||||
|
# time
|
||||||
|
time_embed_dim = block_out_channels[0] * 4
|
||||||
|
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||||
|
timestep_input_dim = block_out_channels[0]
|
||||||
|
self.time_embedding = TimestepEmbedding(
|
||||||
|
timestep_input_dim,
|
||||||
|
time_embed_dim,
|
||||||
|
act_fn=act_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
||||||
|
encoder_hid_dim_type = "text_proj"
|
||||||
|
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
||||||
|
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
||||||
|
|
||||||
|
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if encoder_hid_dim_type == "text_proj":
|
||||||
|
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
||||||
|
elif encoder_hid_dim_type == "text_image_proj":
|
||||||
|
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||||
|
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||||
|
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
||||||
|
self.encoder_hid_proj = TextImageProjection(
|
||||||
|
text_embed_dim=encoder_hid_dim,
|
||||||
|
image_embed_dim=cross_attention_dim,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif encoder_hid_dim_type is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.encoder_hid_proj = None
|
||||||
|
|
||||||
|
# class embedding
|
||||||
|
if class_embed_type is None and num_class_embeds is not None:
|
||||||
|
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||||
|
elif class_embed_type == "timestep":
|
||||||
|
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||||
|
elif class_embed_type == "identity":
|
||||||
|
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||||
|
elif class_embed_type == "projection":
|
||||||
|
if projection_class_embeddings_input_dim is None:
|
||||||
|
raise ValueError(
|
||||||
|
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
||||||
|
)
|
||||||
|
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
||||||
|
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
||||||
|
# 2. it projects from an arbitrary input dimension.
|
||||||
|
#
|
||||||
|
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
||||||
|
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
||||||
|
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
||||||
|
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||||
|
else:
|
||||||
|
self.class_embedding = None
|
||||||
|
|
||||||
|
if addition_embed_type == "text":
|
||||||
|
if encoder_hid_dim is not None:
|
||||||
|
text_time_embedding_from_dim = encoder_hid_dim
|
||||||
|
else:
|
||||||
|
text_time_embedding_from_dim = cross_attention_dim
|
||||||
|
|
||||||
|
self.add_embedding = TextTimeEmbedding(
|
||||||
|
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
||||||
|
)
|
||||||
|
elif addition_embed_type == "text_image":
|
||||||
|
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||||
|
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||||
|
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
||||||
|
self.add_embedding = TextImageTimeEmbedding(
|
||||||
|
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
||||||
|
)
|
||||||
|
elif addition_embed_type == "text_time":
|
||||||
|
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
||||||
|
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||||
|
|
||||||
|
elif addition_embed_type is not None:
|
||||||
|
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
||||||
|
|
||||||
|
self.down_blocks = nn.ModuleList([])
|
||||||
|
self.brushnet_down_blocks = nn.ModuleList([])
|
||||||
|
|
||||||
|
if isinstance(only_cross_attention, bool):
|
||||||
|
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
||||||
|
|
||||||
|
if isinstance(attention_head_dim, int):
|
||||||
|
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||||
|
|
||||||
|
if isinstance(num_attention_heads, int):
|
||||||
|
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||||
|
|
||||||
|
# down
|
||||||
|
output_channel = block_out_channels[0]
|
||||||
|
|
||||||
|
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||||
|
brushnet_block = zero_module(brushnet_block)
|
||||||
|
self.brushnet_down_blocks.append(brushnet_block)
|
||||||
|
|
||||||
|
for i, down_block_type in enumerate(down_block_types):
|
||||||
|
input_channel = output_channel
|
||||||
|
output_channel = block_out_channels[i]
|
||||||
|
is_final_block = i == len(block_out_channels) - 1
|
||||||
|
|
||||||
|
down_block = get_down_block(
|
||||||
|
down_block_type,
|
||||||
|
num_layers=layers_per_block,
|
||||||
|
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
add_downsample=not is_final_block,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
num_attention_heads=num_attention_heads[i],
|
||||||
|
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||||
|
downsample_padding=downsample_padding,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
only_cross_attention=only_cross_attention[i],
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.down_blocks.append(down_block)
|
||||||
|
|
||||||
|
for _ in range(layers_per_block):
|
||||||
|
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||||
|
brushnet_block = zero_module(brushnet_block)
|
||||||
|
self.brushnet_down_blocks.append(brushnet_block)
|
||||||
|
|
||||||
|
if not is_final_block:
|
||||||
|
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||||
|
brushnet_block = zero_module(brushnet_block)
|
||||||
|
self.brushnet_down_blocks.append(brushnet_block)
|
||||||
|
|
||||||
|
# mid
|
||||||
|
mid_block_channel = block_out_channels[-1]
|
||||||
|
|
||||||
|
brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
||||||
|
brushnet_block = zero_module(brushnet_block)
|
||||||
|
self.brushnet_mid_block = brushnet_block
|
||||||
|
|
||||||
|
self.mid_block = MidBlock2D(
|
||||||
|
in_channels=mid_block_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
dropout=0.0,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
output_scale_factor=mid_block_scale_factor,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
)
|
||||||
|
|
||||||
|
# count how many layers upsample the images
|
||||||
|
self.num_upsamplers = 0
|
||||||
|
|
||||||
|
# up
|
||||||
|
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||||
|
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
||||||
|
reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
|
||||||
|
only_cross_attention = list(reversed(only_cross_attention))
|
||||||
|
|
||||||
|
output_channel = reversed_block_out_channels[0]
|
||||||
|
|
||||||
|
self.up_blocks = nn.ModuleList([])
|
||||||
|
self.brushnet_up_blocks = nn.ModuleList([])
|
||||||
|
|
||||||
|
for i, up_block_type in enumerate(up_block_types):
|
||||||
|
is_final_block = i == len(block_out_channels) - 1
|
||||||
|
|
||||||
|
prev_output_channel = output_channel
|
||||||
|
output_channel = reversed_block_out_channels[i]
|
||||||
|
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||||
|
|
||||||
|
# add upsample block for all BUT final layer
|
||||||
|
if not is_final_block:
|
||||||
|
add_upsample = True
|
||||||
|
self.num_upsamplers += 1
|
||||||
|
else:
|
||||||
|
add_upsample = False
|
||||||
|
|
||||||
|
up_block = get_up_block(
|
||||||
|
up_block_type,
|
||||||
|
num_layers=layers_per_block + 1,
|
||||||
|
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
prev_output_channel=prev_output_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
add_upsample=add_upsample,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resolution_idx=i,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
num_attention_heads=reversed_num_attention_heads[i],
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
only_cross_attention=only_cross_attention[i],
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.up_blocks.append(up_block)
|
||||||
|
prev_output_channel = output_channel
|
||||||
|
|
||||||
|
for _ in range(layers_per_block + 1):
|
||||||
|
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||||
|
brushnet_block = zero_module(brushnet_block)
|
||||||
|
self.brushnet_up_blocks.append(brushnet_block)
|
||||||
|
|
||||||
|
if not is_final_block:
|
||||||
|
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||||
|
brushnet_block = zero_module(brushnet_block)
|
||||||
|
self.brushnet_up_blocks.append(brushnet_block)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_unet(
|
||||||
|
cls,
|
||||||
|
unet: UNet2DConditionModel,
|
||||||
|
brushnet_conditioning_channel_order: str = "rgb",
|
||||||
|
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||||
|
load_weights_from_unet: bool = True,
|
||||||
|
conditioning_channels: int = 5,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
unet (`UNet2DConditionModel`):
|
||||||
|
The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
|
||||||
|
where applicable.
|
||||||
|
"""
|
||||||
|
transformer_layers_per_block = (
|
||||||
|
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
||||||
|
)
|
||||||
|
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
||||||
|
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
||||||
|
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
||||||
|
addition_time_embed_dim = (
|
||||||
|
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
||||||
|
)
|
||||||
|
|
||||||
|
brushnet = cls(
|
||||||
|
in_channels=unet.config.in_channels,
|
||||||
|
conditioning_channels=conditioning_channels,
|
||||||
|
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
||||||
|
freq_shift=unet.config.freq_shift,
|
||||||
|
down_block_types=['DownBlock2D', 'DownBlock2D', 'DownBlock2D', 'DownBlock2D'],
|
||||||
|
mid_block_type='MidBlock2D',
|
||||||
|
up_block_types=['UpBlock2D', 'UpBlock2D', 'UpBlock2D', 'UpBlock2D'],
|
||||||
|
only_cross_attention=unet.config.only_cross_attention,
|
||||||
|
block_out_channels=unet.config.block_out_channels,
|
||||||
|
layers_per_block=unet.config.layers_per_block,
|
||||||
|
downsample_padding=unet.config.downsample_padding,
|
||||||
|
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
||||||
|
act_fn=unet.config.act_fn,
|
||||||
|
norm_num_groups=unet.config.norm_num_groups,
|
||||||
|
norm_eps=unet.config.norm_eps,
|
||||||
|
cross_attention_dim=unet.config.cross_attention_dim,
|
||||||
|
transformer_layers_per_block=transformer_layers_per_block,
|
||||||
|
encoder_hid_dim=encoder_hid_dim,
|
||||||
|
encoder_hid_dim_type=encoder_hid_dim_type,
|
||||||
|
attention_head_dim=unet.config.attention_head_dim,
|
||||||
|
num_attention_heads=unet.config.num_attention_heads,
|
||||||
|
use_linear_projection=unet.config.use_linear_projection,
|
||||||
|
class_embed_type=unet.config.class_embed_type,
|
||||||
|
addition_embed_type=addition_embed_type,
|
||||||
|
addition_time_embed_dim=addition_time_embed_dim,
|
||||||
|
num_class_embeds=unet.config.num_class_embeds,
|
||||||
|
upcast_attention=unet.config.upcast_attention,
|
||||||
|
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
||||||
|
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
||||||
|
brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
|
||||||
|
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
if load_weights_from_unet:
|
||||||
|
conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
|
||||||
|
conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
|
||||||
|
conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
|
||||||
|
brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
|
||||||
|
brushnet.conv_in_condition.bias = unet.conv_in.bias
|
||||||
|
|
||||||
|
brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
||||||
|
brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
||||||
|
|
||||||
|
if brushnet.class_embedding:
|
||||||
|
brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
||||||
|
|
||||||
|
brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
|
||||||
|
brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
|
||||||
|
brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
|
||||||
|
|
||||||
|
return brushnet
|
||||||
|
|
||||||
|
@property
|
||||||
|
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||||
|
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||||
|
indexed by its weight name.
|
||||||
|
"""
|
||||||
|
# set recursively
|
||||||
|
processors = {}
|
||||||
|
|
||||||
|
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||||
|
if hasattr(module, "get_processor"):
|
||||||
|
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||||
|
|
||||||
|
for sub_name, child in module.named_children():
|
||||||
|
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||||
|
|
||||||
|
return processors
|
||||||
|
|
||||||
|
for name, module in self.named_children():
|
||||||
|
fn_recursive_add_processors(name, module, processors)
|
||||||
|
|
||||||
|
return processors
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||||
|
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||||
|
r"""
|
||||||
|
Sets the attention processor to use to compute attention.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||||
|
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||||
|
for **all** `Attention` layers.
|
||||||
|
|
||||||
|
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||||
|
processor. This is strongly recommended when setting trainable attention processors.
|
||||||
|
|
||||||
|
"""
|
||||||
|
count = len(self.attn_processors.keys())
|
||||||
|
|
||||||
|
if isinstance(processor, dict) and len(processor) != count:
|
||||||
|
raise ValueError(
|
||||||
|
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||||
|
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||||
|
)
|
||||||
|
|
||||||
|
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||||
|
if hasattr(module, "set_processor"):
|
||||||
|
if not isinstance(processor, dict):
|
||||||
|
module.set_processor(processor)
|
||||||
|
else:
|
||||||
|
module.set_processor(processor.pop(f"{name}.processor"))
|
||||||
|
|
||||||
|
for sub_name, child in module.named_children():
|
||||||
|
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||||
|
|
||||||
|
for name, module in self.named_children():
|
||||||
|
fn_recursive_attn_processor(name, module, processor)
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||||
|
def set_default_attn_processor(self):
|
||||||
|
"""
|
||||||
|
Disables custom attention processors and sets the default attention implementation.
|
||||||
|
"""
|
||||||
|
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||||
|
processor = AttnAddedKVProcessor()
|
||||||
|
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||||
|
processor = AttnProcessor()
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.set_attn_processor(processor)
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||||
|
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||||
|
r"""
|
||||||
|
Enable sliced attention computation.
|
||||||
|
|
||||||
|
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
||||||
|
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
||||||
|
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
||||||
|
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
||||||
|
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
||||||
|
must be a multiple of `slice_size`.
|
||||||
|
"""
|
||||||
|
sliceable_head_dims = []
|
||||||
|
|
||||||
|
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
||||||
|
if hasattr(module, "set_attention_slice"):
|
||||||
|
sliceable_head_dims.append(module.sliceable_head_dim)
|
||||||
|
|
||||||
|
for child in module.children():
|
||||||
|
fn_recursive_retrieve_sliceable_dims(child)
|
||||||
|
|
||||||
|
# retrieve number of attention layers
|
||||||
|
for module in self.children():
|
||||||
|
fn_recursive_retrieve_sliceable_dims(module)
|
||||||
|
|
||||||
|
num_sliceable_layers = len(sliceable_head_dims)
|
||||||
|
|
||||||
|
if slice_size == "auto":
|
||||||
|
# half the attention head size is usually a good trade-off between
|
||||||
|
# speed and memory
|
||||||
|
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
||||||
|
elif slice_size == "max":
|
||||||
|
# make smallest slice possible
|
||||||
|
slice_size = num_sliceable_layers * [1]
|
||||||
|
|
||||||
|
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
||||||
|
|
||||||
|
if len(slice_size) != len(sliceable_head_dims):
|
||||||
|
raise ValueError(
|
||||||
|
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
||||||
|
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(len(slice_size)):
|
||||||
|
size = slice_size[i]
|
||||||
|
dim = sliceable_head_dims[i]
|
||||||
|
if size is not None and size > dim:
|
||||||
|
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
||||||
|
|
||||||
|
# Recursively walk through all the children.
|
||||||
|
# Any children which exposes the set_attention_slice method
|
||||||
|
# gets the message
|
||||||
|
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||||
|
if hasattr(module, "set_attention_slice"):
|
||||||
|
module.set_attention_slice(slice_size.pop())
|
||||||
|
|
||||||
|
for child in module.children():
|
||||||
|
fn_recursive_set_attention_slice(child, slice_size)
|
||||||
|
|
||||||
|
reversed_slice_size = list(reversed(slice_size))
|
||||||
|
for module in self.children():
|
||||||
|
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||||
|
|
||||||
|
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||||
|
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
||||||
|
module.gradient_checkpointing = value
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
sample: torch.FloatTensor,
|
||||||
|
timestep: Union[torch.Tensor, float, int],
|
||||||
|
encoder_hidden_states: torch.Tensor,
|
||||||
|
brushnet_cond: torch.FloatTensor,
|
||||||
|
conditioning_scale: float = 1.0,
|
||||||
|
class_labels: Optional[torch.Tensor] = None,
|
||||||
|
timestep_cond: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
guess_mode: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
|
) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
|
||||||
|
"""
|
||||||
|
The [`BrushNetModel`] forward method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample (`torch.FloatTensor`):
|
||||||
|
The noisy input tensor.
|
||||||
|
timestep (`Union[torch.Tensor, float, int]`):
|
||||||
|
The number of timesteps to denoise an input.
|
||||||
|
encoder_hidden_states (`torch.Tensor`):
|
||||||
|
The encoder hidden states.
|
||||||
|
brushnet_cond (`torch.FloatTensor`):
|
||||||
|
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||||
|
conditioning_scale (`float`, defaults to `1.0`):
|
||||||
|
The scale factor for BrushNet outputs.
|
||||||
|
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||||
|
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
||||||
|
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
||||||
|
embeddings.
|
||||||
|
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||||
|
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||||
|
negative values to the attention scores corresponding to "discard" tokens.
|
||||||
|
added_cond_kwargs (`dict`):
|
||||||
|
Additional conditions for the Stable Diffusion XL UNet.
|
||||||
|
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||||
|
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||||
|
guess_mode (`bool`, defaults to `False`):
|
||||||
|
In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if
|
||||||
|
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||||
|
return_dict (`bool`, defaults to `True`):
|
||||||
|
Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~models.brushnet.BrushNetOutput`] **or** `tuple`:
|
||||||
|
If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
|
||||||
|
returned where the first element is the sample tensor.
|
||||||
|
"""
|
||||||
|
# check channel order
|
||||||
|
channel_order = self.config.brushnet_conditioning_channel_order
|
||||||
|
|
||||||
|
if channel_order == "rgb":
|
||||||
|
# in rgb order by default
|
||||||
|
...
|
||||||
|
elif channel_order == "bgr":
|
||||||
|
brushnet_cond = torch.flip(brushnet_cond, dims=[1])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
|
||||||
|
|
||||||
|
# prepare attention_mask
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||||
|
attention_mask = attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
# 1. time
|
||||||
|
timesteps = timestep
|
||||||
|
if not torch.is_tensor(timesteps):
|
||||||
|
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||||
|
# This would be a good case for the `match` statement (Python 3.10+)
|
||||||
|
is_mps = sample.device.type == "mps"
|
||||||
|
if isinstance(timestep, float):
|
||||||
|
dtype = torch.float32 if is_mps else torch.float64
|
||||||
|
else:
|
||||||
|
dtype = torch.int32 if is_mps else torch.int64
|
||||||
|
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||||
|
elif len(timesteps.shape) == 0:
|
||||||
|
timesteps = timesteps[None].to(sample.device)
|
||||||
|
|
||||||
|
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
|
timesteps = timesteps.expand(sample.shape[0])
|
||||||
|
|
||||||
|
t_emb = self.time_proj(timesteps)
|
||||||
|
|
||||||
|
# timesteps does not contain any weights and will always return f32 tensors
|
||||||
|
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||||
|
# there might be better ways to encapsulate this.
|
||||||
|
t_emb = t_emb.to(dtype=sample.dtype)
|
||||||
|
|
||||||
|
emb = self.time_embedding(t_emb, timestep_cond)
|
||||||
|
aug_emb = None
|
||||||
|
|
||||||
|
if self.class_embedding is not None:
|
||||||
|
if class_labels is None:
|
||||||
|
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||||
|
|
||||||
|
if self.config.class_embed_type == "timestep":
|
||||||
|
class_labels = self.time_proj(class_labels)
|
||||||
|
|
||||||
|
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||||
|
emb = emb + class_emb
|
||||||
|
|
||||||
|
if self.config.addition_embed_type is not None:
|
||||||
|
if self.config.addition_embed_type == "text":
|
||||||
|
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||||
|
|
||||||
|
elif self.config.addition_embed_type == "text_time":
|
||||||
|
if "text_embeds" not in added_cond_kwargs:
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
||||||
|
)
|
||||||
|
text_embeds = added_cond_kwargs.get("text_embeds")
|
||||||
|
if "time_ids" not in added_cond_kwargs:
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
||||||
|
)
|
||||||
|
time_ids = added_cond_kwargs.get("time_ids")
|
||||||
|
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||||
|
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||||
|
|
||||||
|
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||||
|
add_embeds = add_embeds.to(emb.dtype)
|
||||||
|
aug_emb = self.add_embedding(add_embeds)
|
||||||
|
|
||||||
|
emb = emb + aug_emb if aug_emb is not None else emb
|
||||||
|
|
||||||
|
# 2. pre-process
|
||||||
|
brushnet_cond = torch.concat([sample, brushnet_cond], 1)
|
||||||
|
sample = self.conv_in_condition(brushnet_cond)
|
||||||
|
|
||||||
|
# 3. down
|
||||||
|
down_block_res_samples = (sample,)
|
||||||
|
for downsample_block in self.down_blocks:
|
||||||
|
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||||
|
sample, res_samples = downsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||||
|
|
||||||
|
down_block_res_samples += res_samples
|
||||||
|
|
||||||
|
# 4. PaintingNet down blocks
|
||||||
|
brushnet_down_block_res_samples = ()
|
||||||
|
for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
|
||||||
|
down_block_res_sample = brushnet_down_block(down_block_res_sample)
|
||||||
|
brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
|
||||||
|
|
||||||
|
# 5. mid
|
||||||
|
if self.mid_block is not None:
|
||||||
|
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||||
|
sample = self.mid_block(
|
||||||
|
sample,
|
||||||
|
emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample = self.mid_block(sample, emb)
|
||||||
|
|
||||||
|
# 6. BrushNet mid blocks
|
||||||
|
brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
|
||||||
|
|
||||||
|
# 7. up
|
||||||
|
up_block_res_samples = ()
|
||||||
|
for i, upsample_block in enumerate(self.up_blocks):
|
||||||
|
is_final_block = i == len(self.up_blocks) - 1
|
||||||
|
|
||||||
|
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
||||||
|
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||||
|
|
||||||
|
# if we have not reached the final block and need to forward the
|
||||||
|
# upsample size, we do it here
|
||||||
|
if not is_final_block:
|
||||||
|
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||||
|
|
||||||
|
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
||||||
|
sample, up_res_samples = upsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
res_hidden_states_tuple=res_samples,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
upsample_size=upsample_size,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
return_res_samples=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample, up_res_samples = upsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
res_hidden_states_tuple=res_samples,
|
||||||
|
upsample_size=upsample_size,
|
||||||
|
return_res_samples=True
|
||||||
|
)
|
||||||
|
|
||||||
|
up_block_res_samples += up_res_samples
|
||||||
|
|
||||||
|
# 8. BrushNet up blocks
|
||||||
|
brushnet_up_block_res_samples = ()
|
||||||
|
for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
|
||||||
|
up_block_res_sample = brushnet_up_block(up_block_res_sample)
|
||||||
|
brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
|
||||||
|
|
||||||
|
# 6. scaling
|
||||||
|
if guess_mode and not self.config.global_pool_conditions:
|
||||||
|
scales = torch.logspace(-1, 0,
|
||||||
|
len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
|
||||||
|
device=sample.device) # 0.1 to 1.0
|
||||||
|
scales = scales * conditioning_scale
|
||||||
|
|
||||||
|
brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples,
|
||||||
|
scales[:len(
|
||||||
|
brushnet_down_block_res_samples)])]
|
||||||
|
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
|
||||||
|
brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples,
|
||||||
|
scales[
|
||||||
|
len(brushnet_down_block_res_samples) + 1:])]
|
||||||
|
else:
|
||||||
|
brushnet_down_block_res_samples = [sample * conditioning_scale for sample in
|
||||||
|
brushnet_down_block_res_samples]
|
||||||
|
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
|
||||||
|
brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
|
||||||
|
|
||||||
|
if self.config.global_pool_conditions:
|
||||||
|
brushnet_down_block_res_samples = [
|
||||||
|
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
|
||||||
|
]
|
||||||
|
brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||||
|
brushnet_up_block_res_samples = [
|
||||||
|
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
|
||||||
|
]
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
|
||||||
|
|
||||||
|
return BrushNetOutput(
|
||||||
|
down_block_res_samples=brushnet_down_block_res_samples,
|
||||||
|
mid_block_res_sample=brushnet_mid_block_res_sample,
|
||||||
|
up_block_res_samples=brushnet_up_block_res_samples
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def zero_module(module):
|
||||||
|
for p in module.parameters():
|
||||||
|
nn.init.zeros_(p)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
BrushNetModel.from_pretrained("/Users/cwq/data/models/brushnet/brushnet_random_mask", variant='fp16',
|
||||||
|
use_safetensors=True)
|
322
iopaint/model/brushnet/brushnet_unet_forward.py
Normal file
322
iopaint/model/brushnet/brushnet_unet_forward.py
Normal file
@ -0,0 +1,322 @@
|
|||||||
|
from typing import Union, Optional, Dict, Any, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
||||||
|
from diffusers.utils import USE_PEFT_BACKEND, unscale_lora_layers, deprecate, scale_lora_layers
|
||||||
|
|
||||||
|
|
||||||
|
def brushnet_unet_forward(
|
||||||
|
self,
|
||||||
|
sample: torch.FloatTensor,
|
||||||
|
timestep: Union[torch.Tensor, float, int],
|
||||||
|
encoder_hidden_states: torch.Tensor,
|
||||||
|
class_labels: Optional[torch.Tensor] = None,
|
||||||
|
timestep_cond: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
|
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
||||||
|
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
return_dict: bool = True,
|
||||||
|
down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||||
|
r"""
|
||||||
|
The [`UNet2DConditionModel`] forward method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample (`torch.FloatTensor`):
|
||||||
|
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
||||||
|
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||||
|
encoder_hidden_states (`torch.FloatTensor`):
|
||||||
|
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
||||||
|
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||||
|
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
|
||||||
|
through the `self.time_embedding` layer to obtain the timestep embeddings.
|
||||||
|
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||||
|
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||||
|
negative values to the attention scores corresponding to "discard" tokens.
|
||||||
|
cross_attention_kwargs (`dict`, *optional*):
|
||||||
|
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||||
|
`self.processor` in
|
||||||
|
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||||
|
added_cond_kwargs: (`dict`, *optional*):
|
||||||
|
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
||||||
|
are passed along to the UNet blocks.
|
||||||
|
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
|
||||||
|
A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
||||||
|
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
||||||
|
A tensor that if specified is added to the residual of the middle unet block.
|
||||||
|
encoder_attention_mask (`torch.Tensor`):
|
||||||
|
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
||||||
|
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
||||||
|
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||||
|
tuple.
|
||||||
|
cross_attention_kwargs (`dict`, *optional*):
|
||||||
|
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
||||||
|
added_cond_kwargs: (`dict`, *optional*):
|
||||||
|
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
||||||
|
are passed along to the UNet blocks.
|
||||||
|
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
||||||
|
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
|
||||||
|
example from ControlNet side model(s)
|
||||||
|
mid_block_additional_residual (`torch.Tensor`, *optional*):
|
||||||
|
additional residual to be added to UNet mid block output, for example from ControlNet side model
|
||||||
|
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
||||||
|
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||||
|
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
||||||
|
a `tuple` is returned where the first element is the sample tensor.
|
||||||
|
"""
|
||||||
|
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||||
|
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
||||||
|
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||||
|
# on the fly if necessary.
|
||||||
|
default_overall_up_factor = 2 ** self.num_upsamplers
|
||||||
|
|
||||||
|
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||||
|
forward_upsample_size = False
|
||||||
|
upsample_size = None
|
||||||
|
|
||||||
|
for dim in sample.shape[-2:]:
|
||||||
|
if dim % default_overall_up_factor != 0:
|
||||||
|
# Forward upsample size to force interpolation output size.
|
||||||
|
forward_upsample_size = True
|
||||||
|
break
|
||||||
|
|
||||||
|
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
||||||
|
# expects mask of shape:
|
||||||
|
# [batch, key_tokens]
|
||||||
|
# adds singleton query_tokens dimension:
|
||||||
|
# [batch, 1, key_tokens]
|
||||||
|
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||||
|
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||||
|
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||||
|
if attention_mask is not None:
|
||||||
|
# assume that mask is expressed as:
|
||||||
|
# (1 = keep, 0 = discard)
|
||||||
|
# convert mask into a bias that can be added to attention scores:
|
||||||
|
# (keep = +0, discard = -10000.0)
|
||||||
|
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||||
|
attention_mask = attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||||
|
if encoder_attention_mask is not None:
|
||||||
|
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
||||||
|
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
# 0. center input if necessary
|
||||||
|
if self.config.center_input_sample:
|
||||||
|
sample = 2 * sample - 1.0
|
||||||
|
|
||||||
|
# 1. time
|
||||||
|
t_emb = self.get_time_embed(sample=sample, timestep=timestep)
|
||||||
|
emb = self.time_embedding(t_emb, timestep_cond)
|
||||||
|
aug_emb = None
|
||||||
|
|
||||||
|
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
|
||||||
|
if class_emb is not None:
|
||||||
|
if self.config.class_embeddings_concat:
|
||||||
|
emb = torch.cat([emb, class_emb], dim=-1)
|
||||||
|
else:
|
||||||
|
emb = emb + class_emb
|
||||||
|
|
||||||
|
aug_emb = self.get_aug_embed(
|
||||||
|
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
||||||
|
)
|
||||||
|
if self.config.addition_embed_type == "image_hint":
|
||||||
|
aug_emb, hint = aug_emb
|
||||||
|
sample = torch.cat([sample, hint], dim=1)
|
||||||
|
|
||||||
|
emb = emb + aug_emb if aug_emb is not None else emb
|
||||||
|
|
||||||
|
if self.time_embed_act is not None:
|
||||||
|
emb = self.time_embed_act(emb)
|
||||||
|
|
||||||
|
encoder_hidden_states = self.process_encoder_hidden_states(
|
||||||
|
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. pre-process
|
||||||
|
sample = self.conv_in(sample)
|
||||||
|
|
||||||
|
# 2.5 GLIGEN position net
|
||||||
|
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
|
||||||
|
cross_attention_kwargs = cross_attention_kwargs.copy()
|
||||||
|
gligen_args = cross_attention_kwargs.pop("gligen")
|
||||||
|
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
|
||||||
|
|
||||||
|
# 3. down
|
||||||
|
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||||
|
if USE_PEFT_BACKEND:
|
||||||
|
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||||
|
scale_lora_layers(self, lora_scale)
|
||||||
|
|
||||||
|
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
||||||
|
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
|
||||||
|
is_adapter = down_intrablock_additional_residuals is not None
|
||||||
|
# maintain backward compatibility for legacy usage, where
|
||||||
|
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
|
||||||
|
# but can only use one or the other
|
||||||
|
is_brushnet = down_block_add_samples is not None and mid_block_add_sample is not None and up_block_add_samples is not None
|
||||||
|
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
|
||||||
|
deprecate(
|
||||||
|
"T2I should not use down_block_additional_residuals",
|
||||||
|
"1.3.0",
|
||||||
|
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
|
||||||
|
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
|
||||||
|
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
|
||||||
|
standard_warn=False,
|
||||||
|
)
|
||||||
|
down_intrablock_additional_residuals = down_block_additional_residuals
|
||||||
|
is_adapter = True
|
||||||
|
|
||||||
|
down_block_res_samples = (sample,)
|
||||||
|
|
||||||
|
if is_brushnet:
|
||||||
|
sample = sample + down_block_add_samples.pop(0)
|
||||||
|
|
||||||
|
for downsample_block in self.down_blocks:
|
||||||
|
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||||
|
# For t2i-adapter CrossAttnDownBlock2D
|
||||||
|
additional_residuals = {}
|
||||||
|
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
||||||
|
additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
|
||||||
|
|
||||||
|
if is_brushnet and len(down_block_add_samples) > 0:
|
||||||
|
additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
|
||||||
|
for _ in range(
|
||||||
|
len(downsample_block.resnets) + (downsample_block.downsamplers != None))]
|
||||||
|
|
||||||
|
sample, res_samples = downsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
**additional_residuals,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
additional_residuals = {}
|
||||||
|
if is_brushnet and len(down_block_add_samples) > 0:
|
||||||
|
additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
|
||||||
|
for _ in range(
|
||||||
|
len(downsample_block.resnets) + (downsample_block.downsamplers != None))]
|
||||||
|
|
||||||
|
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale,
|
||||||
|
**additional_residuals)
|
||||||
|
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
||||||
|
sample += down_intrablock_additional_residuals.pop(0)
|
||||||
|
|
||||||
|
down_block_res_samples += res_samples
|
||||||
|
|
||||||
|
if is_controlnet:
|
||||||
|
new_down_block_res_samples = ()
|
||||||
|
|
||||||
|
for down_block_res_sample, down_block_additional_residual in zip(
|
||||||
|
down_block_res_samples, down_block_additional_residuals
|
||||||
|
):
|
||||||
|
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
||||||
|
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
||||||
|
|
||||||
|
down_block_res_samples = new_down_block_res_samples
|
||||||
|
|
||||||
|
# 4. mid
|
||||||
|
if self.mid_block is not None:
|
||||||
|
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||||
|
sample = self.mid_block(
|
||||||
|
sample,
|
||||||
|
emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample = self.mid_block(sample, emb)
|
||||||
|
|
||||||
|
# To support T2I-Adapter-XL
|
||||||
|
if (
|
||||||
|
is_adapter
|
||||||
|
and len(down_intrablock_additional_residuals) > 0
|
||||||
|
and sample.shape == down_intrablock_additional_residuals[0].shape
|
||||||
|
):
|
||||||
|
sample += down_intrablock_additional_residuals.pop(0)
|
||||||
|
|
||||||
|
if is_controlnet:
|
||||||
|
sample = sample + mid_block_additional_residual
|
||||||
|
|
||||||
|
if is_brushnet:
|
||||||
|
sample = sample + mid_block_add_sample
|
||||||
|
|
||||||
|
# 5. up
|
||||||
|
for i, upsample_block in enumerate(self.up_blocks):
|
||||||
|
is_final_block = i == len(self.up_blocks) - 1
|
||||||
|
|
||||||
|
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
||||||
|
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||||
|
|
||||||
|
# if we have not reached the final block and need to forward the
|
||||||
|
# upsample size, we do it here
|
||||||
|
if not is_final_block and forward_upsample_size:
|
||||||
|
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||||
|
|
||||||
|
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
||||||
|
additional_residuals = {}
|
||||||
|
if is_brushnet and len(up_block_add_samples) > 0:
|
||||||
|
additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
|
||||||
|
for _ in range(
|
||||||
|
len(upsample_block.resnets) + (upsample_block.upsamplers != None))]
|
||||||
|
|
||||||
|
sample = upsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
res_hidden_states_tuple=res_samples,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
upsample_size=upsample_size,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
**additional_residuals,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
additional_residuals = {}
|
||||||
|
if is_brushnet and len(up_block_add_samples) > 0:
|
||||||
|
additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
|
||||||
|
for _ in range(
|
||||||
|
len(upsample_block.resnets) + (upsample_block.upsamplers != None))]
|
||||||
|
|
||||||
|
sample = upsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
res_hidden_states_tuple=res_samples,
|
||||||
|
upsample_size=upsample_size,
|
||||||
|
scale=lora_scale,
|
||||||
|
**additional_residuals,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. post-process
|
||||||
|
if self.conv_norm_out:
|
||||||
|
sample = self.conv_norm_out(sample)
|
||||||
|
sample = self.conv_act(sample)
|
||||||
|
sample = self.conv_out(sample)
|
||||||
|
|
||||||
|
if USE_PEFT_BACKEND:
|
||||||
|
# remove `lora_scale` from each PEFT layer
|
||||||
|
unscale_lora_layers(self, lora_scale)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (sample,)
|
||||||
|
|
||||||
|
return UNet2DConditionOutput(sample=sample)
|
157
iopaint/model/brushnet/brushnet_wrapper.py
Normal file
157
iopaint/model/brushnet/brushnet_wrapper.py
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
import PIL.Image
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ..base import DiffusionInpaintModel
|
||||||
|
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||||
|
from ..original_sd_configs import get_config_files
|
||||||
|
from ..utils import (
|
||||||
|
handle_from_pretrained_exceptions,
|
||||||
|
get_torch_dtype,
|
||||||
|
enable_low_mem,
|
||||||
|
is_local_files_only,
|
||||||
|
)
|
||||||
|
from .brushnet import BrushNetModel
|
||||||
|
from .brushnet_unet_forward import brushnet_unet_forward
|
||||||
|
from .unet_2d_blocks import CrossAttnDownBlock2D_forward, DownBlock2D_forward, CrossAttnUpBlock2D_forward, \
|
||||||
|
UpBlock2D_forward
|
||||||
|
from ...schema import InpaintRequest, ModelType
|
||||||
|
|
||||||
|
|
||||||
|
class BrushNetWrapper(DiffusionInpaintModel):
|
||||||
|
pad_mod = 8
|
||||||
|
min_size = 512
|
||||||
|
|
||||||
|
def init_model(self, device: torch.device, **kwargs):
|
||||||
|
from .pipeline_brushnet import StableDiffusionBrushNetPipeline
|
||||||
|
self.model_info = kwargs["model_info"]
|
||||||
|
self.brushnet_method = kwargs["brushnet_method"]
|
||||||
|
|
||||||
|
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
|
|
||||||
|
model_kwargs = {
|
||||||
|
**kwargs.get("pipe_components", {}),
|
||||||
|
"local_files_only": is_local_files_only(**kwargs),
|
||||||
|
}
|
||||||
|
self.local_files_only = model_kwargs["local_files_only"]
|
||||||
|
|
||||||
|
disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
|
||||||
|
"cpu_offload", False
|
||||||
|
)
|
||||||
|
if disable_nsfw_checker:
|
||||||
|
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||||
|
model_kwargs.update(
|
||||||
|
dict(
|
||||||
|
safety_checker=None,
|
||||||
|
feature_extractor=None,
|
||||||
|
requires_safety_checker=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Loading BrushNet model from {self.brushnet_method}")
|
||||||
|
brushnet = BrushNetModel.from_pretrained(self.brushnet_method, torch_dtype=torch_dtype)
|
||||||
|
|
||||||
|
if self.model_info.is_single_file_diffusers:
|
||||||
|
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
||||||
|
model_kwargs["num_in_channels"] = 4
|
||||||
|
else:
|
||||||
|
model_kwargs["num_in_channels"] = 9
|
||||||
|
|
||||||
|
self.model = StableDiffusionBrushNetPipeline.from_single_file(
|
||||||
|
self.model_id_or_path,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
load_safety_checker=not disable_nsfw_checker,
|
||||||
|
config_files=get_config_files(),
|
||||||
|
brushnet=brushnet,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.model = handle_from_pretrained_exceptions(
|
||||||
|
StableDiffusionBrushNetPipeline.from_pretrained,
|
||||||
|
pretrained_model_name_or_path=self.model_id_or_path,
|
||||||
|
variant="fp16",
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
brushnet=brushnet,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
enable_low_mem(self.model, kwargs.get("low_mem", False))
|
||||||
|
|
||||||
|
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||||
|
logger.info("Enable sequential cpu offload")
|
||||||
|
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||||
|
else:
|
||||||
|
self.model = self.model.to(device)
|
||||||
|
if kwargs["sd_cpu_textencoder"]:
|
||||||
|
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||||
|
self.model.text_encoder = CPUTextEncoderWrapper(
|
||||||
|
self.model.text_encoder, torch_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
self.callback = kwargs.pop("callback", None)
|
||||||
|
|
||||||
|
# Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
|
||||||
|
self.model.unet.forward = brushnet_unet_forward.__get__(self.model.unet, self.model.unet.__class__)
|
||||||
|
|
||||||
|
for down_block in self.model.brushnet.down_blocks:
|
||||||
|
down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__)
|
||||||
|
for up_block in self.model.brushnet.up_blocks:
|
||||||
|
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)
|
||||||
|
|
||||||
|
# Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
|
||||||
|
for down_block in self.model.unet.down_blocks:
|
||||||
|
if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
|
||||||
|
down_block.forward = CrossAttnDownBlock2D_forward.__get__(down_block, down_block.__class__)
|
||||||
|
else:
|
||||||
|
down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__)
|
||||||
|
|
||||||
|
for up_block in self.model.unet.up_blocks:
|
||||||
|
if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
|
||||||
|
up_block.forward = CrossAttnUpBlock2D_forward.__get__(up_block, up_block.__class__)
|
||||||
|
else:
|
||||||
|
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)
|
||||||
|
|
||||||
|
def switch_brushnet_method(self, new_method: str):
|
||||||
|
self.brushnet_method = new_method
|
||||||
|
brushnet = BrushNetModel.from_pretrained(
|
||||||
|
new_method,
|
||||||
|
resume_download=True,
|
||||||
|
local_files_only=self.local_files_only,
|
||||||
|
torch_dtype=self.torch_dtype,
|
||||||
|
).to(self.model.device)
|
||||||
|
self.model.brushnet = brushnet
|
||||||
|
|
||||||
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
|
"""Input image and output image have same size
|
||||||
|
image: [H, W, C] RGB
|
||||||
|
mask: [H, W, 1] 255 means area to repaint
|
||||||
|
return: BGR IMAGE
|
||||||
|
"""
|
||||||
|
self.set_scheduler(config)
|
||||||
|
|
||||||
|
img_h, img_w = image.shape[:2]
|
||||||
|
normalized_mask = mask[:, :].astype("float32") / 255.0
|
||||||
|
image = image * (1 - normalized_mask)
|
||||||
|
image = image.astype(np.uint8)
|
||||||
|
output = self.model(
|
||||||
|
image=PIL.Image.fromarray(image),
|
||||||
|
prompt=config.prompt,
|
||||||
|
negative_prompt=config.negative_prompt,
|
||||||
|
mask=PIL.Image.fromarray(mask[:, :, -1], mode="L").convert("RGB"),
|
||||||
|
num_inference_steps=config.sd_steps,
|
||||||
|
# strength=config.sd_strength,
|
||||||
|
guidance_scale=config.sd_guidance_scale,
|
||||||
|
output_type="np",
|
||||||
|
callback_on_step_end=self.callback,
|
||||||
|
height=img_h,
|
||||||
|
width=img_w,
|
||||||
|
generator=torch.manual_seed(config.sd_seed),
|
||||||
|
brushnet_conditioning_scale=config.brushnet_conditioning_scale,
|
||||||
|
).images[0]
|
||||||
|
|
||||||
|
output = (output * 255).round().astype("uint8")
|
||||||
|
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||||
|
return output
|
1279
iopaint/model/brushnet/pipeline_brushnet.py
Normal file
1279
iopaint/model/brushnet/pipeline_brushnet.py
Normal file
File diff suppressed because it is too large
Load Diff
388
iopaint/model/brushnet/unet_2d_blocks.py
Normal file
388
iopaint/model/brushnet/unet_2d_blocks.py
Normal file
@ -0,0 +1,388 @@
|
|||||||
|
from typing import Dict, Any, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from diffusers.models.resnet import ResnetBlock2D
|
||||||
|
from diffusers.utils import is_torch_version
|
||||||
|
from diffusers.utils.torch_utils import apply_freeu
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class MidBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
temb_channels: int,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
num_layers: int = 1,
|
||||||
|
resnet_eps: float = 1e-6,
|
||||||
|
resnet_time_scale_shift: str = "default",
|
||||||
|
resnet_act_fn: str = "swish",
|
||||||
|
resnet_groups: int = 32,
|
||||||
|
resnet_pre_norm: bool = True,
|
||||||
|
output_scale_factor: float = 1.0,
|
||||||
|
use_linear_projection: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.has_cross_attention = False
|
||||||
|
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||||
|
|
||||||
|
# there is always at least one resnet
|
||||||
|
resnets = [
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=in_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=output_scale_factor,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
for i in range(num_layers):
|
||||||
|
resnets.append(
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=in_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=output_scale_factor,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.resnets = nn.ModuleList(resnets)
|
||||||
|
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
temb: Optional[torch.FloatTensor] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
lora_scale = 1.0
|
||||||
|
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
|
||||||
|
for resnet in self.resnets[1:]:
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
|
def create_custom_forward(module, return_dict=None):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
if return_dict is not None:
|
||||||
|
return module(*inputs, return_dict=return_dict)
|
||||||
|
else:
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet),
|
||||||
|
hidden_states,
|
||||||
|
temb,
|
||||||
|
**ckpt_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def DownBlock2D_forward(
|
||||||
|
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0,
|
||||||
|
down_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||||
|
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||||
|
output_states = ()
|
||||||
|
|
||||||
|
for resnet in self.resnets:
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
if is_torch_version(">=", "1.11.0"):
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet), hidden_states, temb
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||||
|
|
||||||
|
if down_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + down_block_add_samples.pop(0)
|
||||||
|
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
|
||||||
|
if self.downsamplers is not None:
|
||||||
|
for downsampler in self.downsamplers:
|
||||||
|
hidden_states = downsampler(hidden_states, scale=scale)
|
||||||
|
|
||||||
|
if down_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + down_block_add_samples.pop(0) # todo: add before or after
|
||||||
|
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
|
||||||
|
return hidden_states, output_states
|
||||||
|
|
||||||
|
|
||||||
|
def CrossAttnDownBlock2D_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
temb: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
additional_residuals: Optional[torch.FloatTensor] = None,
|
||||||
|
down_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||||
|
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||||
|
output_states = ()
|
||||||
|
|
||||||
|
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||||
|
|
||||||
|
blocks = list(zip(self.resnets, self.attentions))
|
||||||
|
|
||||||
|
for i, (resnet, attn) in enumerate(blocks):
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
|
def create_custom_forward(module, return_dict=None):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
if return_dict is not None:
|
||||||
|
return module(*inputs, return_dict=return_dict)
|
||||||
|
else:
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet),
|
||||||
|
hidden_states,
|
||||||
|
temb,
|
||||||
|
**ckpt_kwargs,
|
||||||
|
)
|
||||||
|
hidden_states = attn(
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
else:
|
||||||
|
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||||
|
hidden_states = attn(
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# apply additional residuals to the output of the last pair of resnet and attention blocks
|
||||||
|
if i == len(blocks) - 1 and additional_residuals is not None:
|
||||||
|
hidden_states = hidden_states + additional_residuals
|
||||||
|
|
||||||
|
if down_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + down_block_add_samples.pop(0)
|
||||||
|
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
|
||||||
|
if self.downsamplers is not None:
|
||||||
|
for downsampler in self.downsamplers:
|
||||||
|
hidden_states = downsampler(hidden_states, scale=lora_scale)
|
||||||
|
|
||||||
|
if down_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + down_block_add_samples.pop(0) # todo: add before or after
|
||||||
|
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
|
||||||
|
return hidden_states, output_states
|
||||||
|
|
||||||
|
|
||||||
|
def CrossAttnUpBlock2D_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||||
|
temb: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
upsample_size: Optional[int] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
return_res_samples: Optional[bool] = False,
|
||||||
|
up_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||||
|
is_freeu_enabled = (
|
||||||
|
getattr(self, "s1", None)
|
||||||
|
and getattr(self, "s2", None)
|
||||||
|
and getattr(self, "b1", None)
|
||||||
|
and getattr(self, "b2", None)
|
||||||
|
)
|
||||||
|
if return_res_samples:
|
||||||
|
output_states = ()
|
||||||
|
|
||||||
|
for resnet, attn in zip(self.resnets, self.attentions):
|
||||||
|
# pop res hidden states
|
||||||
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
|
||||||
|
# FreeU: Only operate on the first two stages
|
||||||
|
if is_freeu_enabled:
|
||||||
|
hidden_states, res_hidden_states = apply_freeu(
|
||||||
|
self.resolution_idx,
|
||||||
|
hidden_states,
|
||||||
|
res_hidden_states,
|
||||||
|
s1=self.s1,
|
||||||
|
s2=self.s2,
|
||||||
|
b1=self.b1,
|
||||||
|
b2=self.b2,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
|
def create_custom_forward(module, return_dict=None):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
if return_dict is not None:
|
||||||
|
return module(*inputs, return_dict=return_dict)
|
||||||
|
else:
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet),
|
||||||
|
hidden_states,
|
||||||
|
temb,
|
||||||
|
**ckpt_kwargs,
|
||||||
|
)
|
||||||
|
hidden_states = attn(
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
else:
|
||||||
|
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||||
|
hidden_states = attn(
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
if return_res_samples:
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
if up_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + up_block_add_samples.pop(0)
|
||||||
|
|
||||||
|
if self.upsamplers is not None:
|
||||||
|
for upsampler in self.upsamplers:
|
||||||
|
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
|
||||||
|
if return_res_samples:
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
if up_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + up_block_add_samples.pop(0)
|
||||||
|
|
||||||
|
if return_res_samples:
|
||||||
|
return hidden_states, output_states
|
||||||
|
else:
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def UpBlock2D_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||||
|
temb: Optional[torch.FloatTensor] = None,
|
||||||
|
upsample_size: Optional[int] = None,
|
||||||
|
scale: float = 1.0,
|
||||||
|
return_res_samples: Optional[bool] = False,
|
||||||
|
up_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
is_freeu_enabled = (
|
||||||
|
getattr(self, "s1", None)
|
||||||
|
and getattr(self, "s2", None)
|
||||||
|
and getattr(self, "b1", None)
|
||||||
|
and getattr(self, "b2", None)
|
||||||
|
)
|
||||||
|
if return_res_samples:
|
||||||
|
output_states = ()
|
||||||
|
|
||||||
|
for resnet in self.resnets:
|
||||||
|
# pop res hidden states
|
||||||
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
|
||||||
|
# FreeU: Only operate on the first two stages
|
||||||
|
if is_freeu_enabled:
|
||||||
|
hidden_states, res_hidden_states = apply_freeu(
|
||||||
|
self.resolution_idx,
|
||||||
|
hidden_states,
|
||||||
|
res_hidden_states,
|
||||||
|
s1=self.s1,
|
||||||
|
s2=self.s2,
|
||||||
|
b1=self.b1,
|
||||||
|
b2=self.b2,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
if is_torch_version(">=", "1.11.0"):
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet), hidden_states, temb
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||||
|
|
||||||
|
if return_res_samples:
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
if up_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + up_block_add_samples.pop(0) # todo: add before or after
|
||||||
|
|
||||||
|
if self.upsamplers is not None:
|
||||||
|
for upsampler in self.upsamplers:
|
||||||
|
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
|
||||||
|
|
||||||
|
if return_res_samples:
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
if up_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + up_block_add_samples.pop(0) # todo: add before or after
|
||||||
|
|
||||||
|
if return_res_samples:
|
||||||
|
return hidden_states, output_states
|
||||||
|
else:
|
||||||
|
return hidden_states
|
@ -7,6 +7,7 @@ import numpy as np
|
|||||||
from iopaint.download import scan_models
|
from iopaint.download import scan_models
|
||||||
from iopaint.helper import switch_mps_device
|
from iopaint.helper import switch_mps_device
|
||||||
from iopaint.model import models, ControlNet, SD, SDXL
|
from iopaint.model import models, ControlNet, SD, SDXL
|
||||||
|
from iopaint.model.brushnet.brushnet_wrapper import BrushNetWrapper
|
||||||
from iopaint.model.utils import torch_gc, is_local_files_only
|
from iopaint.model.utils import torch_gc, is_local_files_only
|
||||||
from iopaint.schema import InpaintRequest, ModelInfo, ModelType
|
from iopaint.schema import InpaintRequest, ModelInfo, ModelType
|
||||||
|
|
||||||
@ -22,12 +23,16 @@ class ModelManager:
|
|||||||
self.enable_controlnet = kwargs.get("enable_controlnet", False)
|
self.enable_controlnet = kwargs.get("enable_controlnet", False)
|
||||||
controlnet_method = kwargs.get("controlnet_method", None)
|
controlnet_method = kwargs.get("controlnet_method", None)
|
||||||
if (
|
if (
|
||||||
controlnet_method is None
|
controlnet_method is None
|
||||||
and name in self.available_models
|
and name in self.available_models
|
||||||
and self.available_models[name].support_controlnet
|
and self.available_models[name].support_controlnet
|
||||||
):
|
):
|
||||||
controlnet_method = self.available_models[name].controlnets[0]
|
controlnet_method = self.available_models[name].controlnets[0]
|
||||||
self.controlnet_method = controlnet_method
|
self.controlnet_method = controlnet_method
|
||||||
|
|
||||||
|
self.enable_brushnet = kwargs.get("enable_brushnet", False)
|
||||||
|
self.brushnet_method = kwargs.get("brushnet_method", None)
|
||||||
|
|
||||||
self.model = self.init_model(name, device, **kwargs)
|
self.model = self.init_model(name, device, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -47,24 +52,30 @@ class ModelManager:
|
|||||||
"model_info": model_info,
|
"model_info": model_info,
|
||||||
"enable_controlnet": self.enable_controlnet,
|
"enable_controlnet": self.enable_controlnet,
|
||||||
"controlnet_method": self.controlnet_method,
|
"controlnet_method": self.controlnet_method,
|
||||||
|
"enable_brushnet": self.enable_brushnet,
|
||||||
|
"brushnet_method": self.brushnet_method,
|
||||||
}
|
}
|
||||||
|
|
||||||
if model_info.support_controlnet and self.enable_controlnet:
|
if model_info.support_controlnet and self.enable_controlnet:
|
||||||
return ControlNet(device, **kwargs)
|
return ControlNet(device, **kwargs)
|
||||||
elif model_info.name in models:
|
|
||||||
return models[name](device, **kwargs)
|
|
||||||
else:
|
|
||||||
if model_info.model_type in [
|
|
||||||
ModelType.DIFFUSERS_SD_INPAINT,
|
|
||||||
ModelType.DIFFUSERS_SD,
|
|
||||||
]:
|
|
||||||
return SD(device, **kwargs)
|
|
||||||
|
|
||||||
if model_info.model_type in [
|
if model_info.support_brushnet and self.enable_brushnet:
|
||||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
return BrushNetWrapper(device, **kwargs)
|
||||||
ModelType.DIFFUSERS_SDXL,
|
|
||||||
]:
|
if model_info.name in models:
|
||||||
return SDXL(device, **kwargs)
|
return models[name](device, **kwargs)
|
||||||
|
|
||||||
|
if model_info.model_type in [
|
||||||
|
ModelType.DIFFUSERS_SD_INPAINT,
|
||||||
|
ModelType.DIFFUSERS_SD,
|
||||||
|
]:
|
||||||
|
return SD(device, **kwargs)
|
||||||
|
|
||||||
|
if model_info.model_type in [
|
||||||
|
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||||
|
ModelType.DIFFUSERS_SDXL,
|
||||||
|
]:
|
||||||
|
return SDXL(device, **kwargs)
|
||||||
|
|
||||||
raise NotImplementedError(f"Unsupported model: {name}")
|
raise NotImplementedError(f"Unsupported model: {name}")
|
||||||
|
|
||||||
@ -80,7 +91,10 @@ class ModelManager:
|
|||||||
Returns:
|
Returns:
|
||||||
BGR image
|
BGR image
|
||||||
"""
|
"""
|
||||||
self.switch_controlnet_method(config)
|
if not config.enable_brushnet:
|
||||||
|
self.switch_controlnet_method(config)
|
||||||
|
if not config.enable_controlnet:
|
||||||
|
self.switch_brushnet_method(config)
|
||||||
self.enable_disable_freeu(config)
|
self.enable_disable_freeu(config)
|
||||||
self.enable_disable_lcm_lora(config)
|
self.enable_disable_lcm_lora(config)
|
||||||
return self.model(image, mask, config).astype(np.uint8)
|
return self.model(image, mask, config).astype(np.uint8)
|
||||||
@ -99,9 +113,9 @@ class ModelManager:
|
|||||||
self.name = new_name
|
self.name = new_name
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.available_models[new_name].support_controlnet
|
self.available_models[new_name].support_controlnet
|
||||||
and self.controlnet_method
|
and self.controlnet_method
|
||||||
not in self.available_models[new_name].controlnets
|
not in self.available_models[new_name].controlnets
|
||||||
):
|
):
|
||||||
self.controlnet_method = self.available_models[new_name].controlnets[0]
|
self.controlnet_method = self.available_models[new_name].controlnets[0]
|
||||||
try:
|
try:
|
||||||
@ -121,14 +135,54 @@ class ModelManager:
|
|||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
def switch_brushnet_method(self, config):
|
||||||
|
if not self.available_models[self.name].support_brushnet:
|
||||||
|
return
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.enable_brushnet
|
||||||
|
and config.brushnet_method
|
||||||
|
and self.brushnet_method != config.brushnet_method
|
||||||
|
):
|
||||||
|
old_brushnet_method = self.brushnet_method
|
||||||
|
self.brushnet_method = config.brushnet_method
|
||||||
|
self.model.switch_brushnet_method(config.brushnet_method)
|
||||||
|
logger.info(
|
||||||
|
f"Switch Brushnet method from {old_brushnet_method} to {config.brushnet_method}"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif self.enable_brushnet != config.enable_brushnet:
|
||||||
|
self.enable_brushnet = config.enable_brushnet
|
||||||
|
self.brushnet_method = config.brushnet_method
|
||||||
|
|
||||||
|
pipe_components = {
|
||||||
|
"vae": self.model.model.vae,
|
||||||
|
"text_encoder": self.model.model.text_encoder,
|
||||||
|
"unet": self.model.model.unet,
|
||||||
|
}
|
||||||
|
if hasattr(self.model.model, "text_encoder_2"):
|
||||||
|
pipe_components["text_encoder_2"] = self.model.model.text_encoder_2
|
||||||
|
|
||||||
|
self.model = self.init_model(
|
||||||
|
self.name,
|
||||||
|
switch_mps_device(self.name, self.device),
|
||||||
|
pipe_components=pipe_components,
|
||||||
|
**self.kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not config.enable_brushnet:
|
||||||
|
logger.info("BrushNet Disabled")
|
||||||
|
else:
|
||||||
|
logger.info("BrushNet Enabled")
|
||||||
|
|
||||||
def switch_controlnet_method(self, config):
|
def switch_controlnet_method(self, config):
|
||||||
if not self.available_models[self.name].support_controlnet:
|
if not self.available_models[self.name].support_controlnet:
|
||||||
return
|
return
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.enable_controlnet
|
self.enable_controlnet
|
||||||
and config.controlnet_method
|
and config.controlnet_method
|
||||||
and self.controlnet_method != config.controlnet_method
|
and self.controlnet_method != config.controlnet_method
|
||||||
):
|
):
|
||||||
old_controlnet_method = self.controlnet_method
|
old_controlnet_method = self.controlnet_method
|
||||||
self.controlnet_method = config.controlnet_method
|
self.controlnet_method = config.controlnet_method
|
||||||
@ -155,7 +209,7 @@ class ModelManager:
|
|||||||
**self.kwargs,
|
**self.kwargs,
|
||||||
)
|
)
|
||||||
if not config.enable_controlnet:
|
if not config.enable_controlnet:
|
||||||
logger.info(f"Disable controlnet")
|
logger.info("Disable controlnet")
|
||||||
else:
|
else:
|
||||||
logger.info(f"Enable controlnet: {config.controlnet_method}")
|
logger.info(f"Enable controlnet: {config.controlnet_method}")
|
||||||
|
|
||||||
|
@ -3,6 +3,8 @@ from enum import Enum
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Literal, List
|
from typing import Optional, Literal, List
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from iopaint.const import (
|
from iopaint.const import (
|
||||||
INSTRUCT_PIX2PIX_NAME,
|
INSTRUCT_PIX2PIX_NAME,
|
||||||
KANDINSKY22_NAME,
|
KANDINSKY22_NAME,
|
||||||
@ -11,9 +13,9 @@ from iopaint.const import (
|
|||||||
SDXL_CONTROLNET_CHOICES,
|
SDXL_CONTROLNET_CHOICES,
|
||||||
SD2_CONTROLNET_CHOICES,
|
SD2_CONTROLNET_CHOICES,
|
||||||
SD_CONTROLNET_CHOICES,
|
SD_CONTROLNET_CHOICES,
|
||||||
|
SD_BRUSHNET_CHOICES,
|
||||||
)
|
)
|
||||||
from loguru import logger
|
from pydantic import BaseModel, Field, computed_field, model_validator
|
||||||
from pydantic import BaseModel, Field, field_validator, computed_field
|
|
||||||
|
|
||||||
|
|
||||||
class ModelType(str, Enum):
|
class ModelType(str, Enum):
|
||||||
@ -63,6 +65,13 @@ class ModelInfo(BaseModel):
|
|||||||
return SD_CONTROLNET_CHOICES
|
return SD_CONTROLNET_CHOICES
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def brushnets(self) -> List[str]:
|
||||||
|
if self.model_type in [ModelType.DIFFUSERS_SD]:
|
||||||
|
return SD_BRUSHNET_CHOICES
|
||||||
|
return []
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
@property
|
@property
|
||||||
def support_strength(self) -> bool:
|
def support_strength(self) -> bool:
|
||||||
@ -103,6 +112,13 @@ class ModelInfo(BaseModel):
|
|||||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def support_brushnet(self) -> bool:
|
||||||
|
return self.model_type in [
|
||||||
|
ModelType.DIFFUSERS_SD,
|
||||||
|
]
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
@property
|
@property
|
||||||
def support_freeu(self) -> bool:
|
def support_freeu(self) -> bool:
|
||||||
@ -369,6 +385,13 @@ class InpaintRequest(BaseModel):
|
|||||||
"lllyasviel/control_v11p_sd15_canny", description="Controlnet method"
|
"lllyasviel/control_v11p_sd15_canny", description="Controlnet method"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# BrushNet
|
||||||
|
enable_brushnet: bool = Field(False, description="Enable brushnet")
|
||||||
|
brushnet_method: str = Field(
|
||||||
|
SD_BRUSHNET_CHOICES[0], description="Brushnet method"
|
||||||
|
)
|
||||||
|
brushnet_conditioning_scale: float = Field(1.0, description="brushnet conditioning scale", ge=0.0, le=1.0)
|
||||||
|
|
||||||
# PowerPaint
|
# PowerPaint
|
||||||
powerpaint_task: PowerPaintTask = Field(
|
powerpaint_task: PowerPaintTask = Field(
|
||||||
PowerPaintTask.text_guided, description="PowerPaint task"
|
PowerPaintTask.text_guided, description="PowerPaint task"
|
||||||
@ -380,31 +403,37 @@ class InpaintRequest(BaseModel):
|
|||||||
le=1.0,
|
le=1.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@field_validator("sd_seed")
|
@model_validator(mode='after')
|
||||||
@classmethod
|
def validate_field(cls, values: 'InpaintRequest'):
|
||||||
def sd_seed_validator(cls, v: int) -> int:
|
if values.sd_seed == -1:
|
||||||
if v == -1:
|
values.sd_seed = random.randint(1, 99999999)
|
||||||
return random.randint(1, 99999999)
|
logger.info(f"Generate random seed: {values.sd_seed}")
|
||||||
return v
|
|
||||||
|
|
||||||
@field_validator("controlnet_conditioning_scale")
|
if values.use_extender and values.enable_controlnet:
|
||||||
@classmethod
|
logger.info("Extender is enabled, set controlnet_conditioning_scale=0")
|
||||||
def validate_field(cls, v: float, values):
|
values.controlnet_conditioning_scale = 0
|
||||||
use_extender = values.data["use_extender"]
|
|
||||||
enable_controlnet = values.data["enable_controlnet"]
|
|
||||||
if use_extender and enable_controlnet:
|
|
||||||
logger.info(f"Extender is enabled, set controlnet_conditioning_scale=0")
|
|
||||||
return 0
|
|
||||||
return v
|
|
||||||
|
|
||||||
@field_validator("sd_strength")
|
if values.use_extender:
|
||||||
@classmethod
|
logger.info("Extender is enabled, set sd_strength=1")
|
||||||
def validate_sd_strength(cls, v: float, values):
|
values.sd_strength = 1.0
|
||||||
use_extender = values.data["use_extender"]
|
|
||||||
if use_extender:
|
if values.enable_brushnet:
|
||||||
logger.info(f"Extender is enabled, set sd_strength=1")
|
logger.info("BrushNet is enabled, set enable_controlnet=False")
|
||||||
return 1.0
|
if values.enable_controlnet:
|
||||||
return v
|
values.enable_controlnet = False
|
||||||
|
if values.sd_lcm_lora:
|
||||||
|
logger.info("BrushNet is enabled, set sd_lcm_lora=False")
|
||||||
|
values.sd_lcm_lora = False
|
||||||
|
if values.sd_freeu:
|
||||||
|
logger.info("BrushNet is enabled, set sd_freeu=False")
|
||||||
|
values.sd_freeu = False
|
||||||
|
|
||||||
|
if values.enable_controlnet:
|
||||||
|
logger.info("ControlNet is enabled, set enable_brushnet=False")
|
||||||
|
if values.enable_brushnet:
|
||||||
|
values.enable_brushnet = False
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
class RunPluginRequest(BaseModel):
|
class RunPluginRequest(BaseModel):
|
||||||
|
89
iopaint/tests/test_brushnet.py
Normal file
89
iopaint/tests/test_brushnet.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from iopaint.const import SD_BRUSHNET_CHOICES
|
||||||
|
from iopaint.tests.utils import check_device, get_config, assert_equal
|
||||||
|
|
||||||
|
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from iopaint.model_manager import ModelManager
|
||||||
|
from iopaint.schema import HDStrategy, SDSampler, FREEUConfig
|
||||||
|
|
||||||
|
current_dir = Path(__file__).parent.absolute().resolve()
|
||||||
|
save_dir = current_dir / "result"
|
||||||
|
save_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
||||||
|
@pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m_karras])
|
||||||
|
def test_runway_brushnet(device, sampler):
|
||||||
|
sd_steps = check_device(device)
|
||||||
|
model = ModelManager(
|
||||||
|
name="runwayml/stable-diffusion-v1-5",
|
||||||
|
device=torch.device(device),
|
||||||
|
disable_nsfw=True,
|
||||||
|
sd_cpu_textencoder=False,
|
||||||
|
)
|
||||||
|
cfg = get_config(
|
||||||
|
strategy=HDStrategy.ORIGINAL,
|
||||||
|
prompt="face of a fox, sitting on a bench",
|
||||||
|
sd_steps=sd_steps,
|
||||||
|
sd_guidance_scale=7.5,
|
||||||
|
sd_freeu=True,
|
||||||
|
sd_freeu_config=FREEUConfig(),
|
||||||
|
enable_brushnet=True,
|
||||||
|
brushnet_method=SD_BRUSHNET_CHOICES[0]
|
||||||
|
)
|
||||||
|
cfg.sd_sampler = sampler
|
||||||
|
|
||||||
|
assert_equal(
|
||||||
|
model,
|
||||||
|
cfg,
|
||||||
|
f"brushnet_runway_1_5_freeu_device_{device}.png",
|
||||||
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||||
|
@pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m_karras])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"name",
|
||||||
|
[
|
||||||
|
"v1-5-pruned-emaonly.safetensors",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_brushnet_local_file_path(device, sampler, name):
|
||||||
|
sd_steps = check_device(device)
|
||||||
|
model = ModelManager(
|
||||||
|
name=name,
|
||||||
|
device=torch.device(device),
|
||||||
|
disable_nsfw=True,
|
||||||
|
sd_cpu_textencoder=False,
|
||||||
|
cpu_offload=False,
|
||||||
|
)
|
||||||
|
cfg = get_config(
|
||||||
|
strategy=HDStrategy.ORIGINAL,
|
||||||
|
prompt="face of a fox, sitting on a bench",
|
||||||
|
sd_steps=sd_steps,
|
||||||
|
sd_seed=1234,
|
||||||
|
enable_brushnet=True,
|
||||||
|
brushnet_method=SD_BRUSHNET_CHOICES[1]
|
||||||
|
)
|
||||||
|
cfg.sd_sampler = sampler
|
||||||
|
name = f"device_{device}_{sampler}_{name}"
|
||||||
|
|
||||||
|
is_sdxl = "sd_xl" in name
|
||||||
|
|
||||||
|
assert_equal(
|
||||||
|
model,
|
||||||
|
cfg,
|
||||||
|
f"brushnet_sd_local_model_{name}.png",
|
||||||
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
|
fx=1.5 if is_sdxl else 1,
|
||||||
|
fy=1.5 if is_sdxl else 1,
|
||||||
|
)
|
@ -109,6 +109,84 @@ const DiffusionOptions = () => {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const renderBrushNetSetting = () => {
|
||||||
|
if (!settings.model.support_brushnet) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex flex-col gap-4">
|
||||||
|
<div className="flex flex-col gap-4">
|
||||||
|
<div className="flex justify-between items-center pr-2">
|
||||||
|
<LabelTitle
|
||||||
|
text="BrushNet"
|
||||||
|
toolTip="BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion"
|
||||||
|
url="https://github.com/TencentARC/BrushNet"
|
||||||
|
/>
|
||||||
|
<Switch
|
||||||
|
id="brushnet"
|
||||||
|
checked={settings.enableBrushNet}
|
||||||
|
onCheckedChange={(value) => {
|
||||||
|
updateSettings({ enableBrushNet: value })
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
<div className="flex flex-col gap-1">
|
||||||
|
<RowContainer>
|
||||||
|
<Slider
|
||||||
|
className="w-[180px]"
|
||||||
|
defaultValue={[100]}
|
||||||
|
min={1}
|
||||||
|
max={100}
|
||||||
|
step={1}
|
||||||
|
disabled={!settings.enableBrushNet}
|
||||||
|
value={[Math.floor(settings.brushnetConditioningScale * 100)]}
|
||||||
|
onValueChange={(vals) =>
|
||||||
|
updateSettings({ brushnetConditioningScale: vals[0] / 100 })
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
<NumberInput
|
||||||
|
id="controlnet-weight"
|
||||||
|
className="w-[60px] rounded-full"
|
||||||
|
disabled={!settings.enableBrushNet}
|
||||||
|
numberValue={settings.brushnetConditioningScale}
|
||||||
|
allowFloat={false}
|
||||||
|
onNumberValueChange={(val) => {
|
||||||
|
updateSettings({ brushnetConditioningScale: val })
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</RowContainer>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="pr-2">
|
||||||
|
<Select
|
||||||
|
defaultValue={settings.brushnetMethod}
|
||||||
|
value={settings.brushnetMethod}
|
||||||
|
onValueChange={(value) => {
|
||||||
|
updateSettings({ brushnetMethod: value })
|
||||||
|
}}
|
||||||
|
disabled={!settings.enableBrushNet}
|
||||||
|
>
|
||||||
|
<SelectTrigger>
|
||||||
|
<SelectValue placeholder="Select brushnet model" />
|
||||||
|
</SelectTrigger>
|
||||||
|
<SelectContent align="end">
|
||||||
|
<SelectGroup>
|
||||||
|
{Object.values(settings.model.brushnets).map((method) => (
|
||||||
|
<SelectItem key={method} value={method}>
|
||||||
|
{method}
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
|
</SelectGroup>
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<Separator />
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
const renderConterNetSetting = () => {
|
const renderConterNetSetting = () => {
|
||||||
if (!settings.model.support_controlnet) {
|
if (!settings.model.support_controlnet) {
|
||||||
return null
|
return null
|
||||||
@ -881,6 +959,7 @@ const DiffusionOptions = () => {
|
|||||||
{renderSeed()}
|
{renderSeed()}
|
||||||
{renderNegativePrompt()}
|
{renderNegativePrompt()}
|
||||||
<Separator />
|
<Separator />
|
||||||
|
{renderBrushNetSetting()}
|
||||||
{renderConterNetSetting()}
|
{renderConterNetSetting()}
|
||||||
{renderLCMLora()}
|
{renderLCMLora()}
|
||||||
{renderMaskBlur()}
|
{renderMaskBlur()}
|
||||||
|
@ -78,6 +78,9 @@ export default async function inpaint(
|
|||||||
controlnet_method: settings.controlnetMethod
|
controlnet_method: settings.controlnetMethod
|
||||||
? settings.controlnetMethod
|
? settings.controlnetMethod
|
||||||
: "",
|
: "",
|
||||||
|
enable_brushnet: settings.enableBrushNet,
|
||||||
|
brushnet_method: settings.brushnetMethod ? settings.brushnetMethod : "",
|
||||||
|
brushnet_conditioning_scale: settings.brushnetConditioningScale,
|
||||||
powerpaint_task: settings.showExtender
|
powerpaint_task: settings.showExtender
|
||||||
? PowerPaintTask.outpainting
|
? PowerPaintTask.outpainting
|
||||||
: settings.powerpaintTask,
|
: settings.powerpaintTask,
|
||||||
|
@ -99,6 +99,11 @@ export type Settings = {
|
|||||||
controlnetConditioningScale: number
|
controlnetConditioningScale: number
|
||||||
controlnetMethod: string
|
controlnetMethod: string
|
||||||
|
|
||||||
|
// BrushNet
|
||||||
|
enableBrushNet: boolean
|
||||||
|
brushnetMethod: string
|
||||||
|
brushnetConditioningScale: number
|
||||||
|
|
||||||
enableLCMLora: boolean
|
enableLCMLora: boolean
|
||||||
enableFreeu: boolean
|
enableFreeu: boolean
|
||||||
freeuConfig: FreeuConfig
|
freeuConfig: FreeuConfig
|
||||||
@ -306,15 +311,16 @@ const defaultValues: AppState = {
|
|||||||
path: "lama",
|
path: "lama",
|
||||||
model_type: "inpaint",
|
model_type: "inpaint",
|
||||||
support_controlnet: false,
|
support_controlnet: false,
|
||||||
|
support_brushnet: false,
|
||||||
support_strength: false,
|
support_strength: false,
|
||||||
support_outpainting: false,
|
support_outpainting: false,
|
||||||
controlnets: [],
|
controlnets: [],
|
||||||
|
brushnets: [],
|
||||||
support_freeu: false,
|
support_freeu: false,
|
||||||
support_lcm_lora: false,
|
support_lcm_lora: false,
|
||||||
is_single_file_diffusers: false,
|
is_single_file_diffusers: false,
|
||||||
need_prompt: false,
|
need_prompt: false,
|
||||||
},
|
},
|
||||||
enableControlnet: false,
|
|
||||||
showCropper: false,
|
showCropper: false,
|
||||||
showExtender: false,
|
showExtender: false,
|
||||||
extenderDirection: ExtenderDirection.xy,
|
extenderDirection: ExtenderDirection.xy,
|
||||||
@ -339,8 +345,12 @@ const defaultValues: AppState = {
|
|||||||
sdMatchHistograms: false,
|
sdMatchHistograms: false,
|
||||||
sdScale: 1.0,
|
sdScale: 1.0,
|
||||||
p2pImageGuidanceScale: 1.5,
|
p2pImageGuidanceScale: 1.5,
|
||||||
controlnetConditioningScale: 0.4,
|
enableControlnet: false,
|
||||||
controlnetMethod: "lllyasviel/control_v11p_sd15_canny",
|
controlnetMethod: "lllyasviel/control_v11p_sd15_canny",
|
||||||
|
controlnetConditioningScale: 0.4,
|
||||||
|
enableBrushNet: false,
|
||||||
|
brushnetMethod: "random_mask",
|
||||||
|
brushnetConditioningScale: 1.0,
|
||||||
enableLCMLora: false,
|
enableLCMLora: false,
|
||||||
enableFreeu: false,
|
enableFreeu: false,
|
||||||
freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 },
|
freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 },
|
||||||
@ -1076,7 +1086,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
})),
|
})),
|
||||||
{
|
{
|
||||||
name: "ZUSTAND_STATE", // name of the item in the storage (must be unique)
|
name: "ZUSTAND_STATE", // name of the item in the storage (must be unique)
|
||||||
version: 1,
|
version: 2,
|
||||||
partialize: (state) =>
|
partialize: (state) =>
|
||||||
Object.fromEntries(
|
Object.fromEntries(
|
||||||
Object.entries(state).filter(([key]) =>
|
Object.entries(state).filter(([key]) =>
|
||||||
|
@ -48,7 +48,9 @@ export interface ModelInfo {
|
|||||||
support_strength: boolean
|
support_strength: boolean
|
||||||
support_outpainting: boolean
|
support_outpainting: boolean
|
||||||
support_controlnet: boolean
|
support_controlnet: boolean
|
||||||
|
support_brushnet: boolean
|
||||||
controlnets: string[]
|
controlnets: string[]
|
||||||
|
brushnets: string[]
|
||||||
support_freeu: boolean
|
support_freeu: boolean
|
||||||
support_lcm_lora: boolean
|
support_lcm_lora: boolean
|
||||||
need_prompt: boolean
|
need_prompt: boolean
|
||||||
|
Loading…
Reference in New Issue
Block a user