make brushnet work
This commit is contained in:
parent
35f12d5b9b
commit
0a262fa811
@ -6,7 +6,6 @@ KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
|
||||
POWERPAINT_NAME = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
|
||||
ANYTEXT_NAME = "Sanster/AnyText"
|
||||
|
||||
|
||||
DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
|
||||
DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
|
||||
DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
|
||||
@ -62,6 +61,11 @@ SD_CONTROLNET_CHOICES: List[str] = [
|
||||
"lllyasviel/control_v11f1p_sd15_depth",
|
||||
]
|
||||
|
||||
SD_BRUSHNET_CHOICES: List[str] = [
|
||||
"Sanster/brushnet_random_mask",
|
||||
"Sanster/brushnet_segmentation_mask"
|
||||
]
|
||||
|
||||
SD2_CONTROLNET_CHOICES = [
|
||||
"thibaud/controlnet-sd21-canny-diffusers",
|
||||
"thibaud/controlnet-sd21-depth-diffusers",
|
||||
|
@ -1,3 +1,4 @@
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
from functools import lru_cache
|
||||
@ -92,7 +93,7 @@ def get_sdxl_model_type(model_abs_path: str) -> ModelType:
|
||||
else:
|
||||
model_type = ModelType.DIFFUSERS_SDXL
|
||||
except ValueError as e:
|
||||
if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e):
|
||||
if "but got torch.Size([320, 4, 3, 3])" in str(e):
|
||||
model_type = ModelType.DIFFUSERS_SDXL
|
||||
else:
|
||||
raise e
|
||||
@ -192,7 +193,9 @@ def scan_diffusers_models() -> List[ModelInfo]:
|
||||
cache_dir = Path(HF_HUB_CACHE)
|
||||
# logger.info(f"Scanning diffusers models in {cache_dir}")
|
||||
diffusers_model_names = []
|
||||
for it in cache_dir.glob("**/*/model_index.json"):
|
||||
model_index_files = glob.glob(os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True)
|
||||
for it in model_index_files:
|
||||
it = Path(it)
|
||||
with open(it, "r", encoding="utf-8") as f:
|
||||
try:
|
||||
data = json.load(f)
|
||||
@ -238,7 +241,9 @@ def _scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
|
||||
cache_dir = Path(cache_dir)
|
||||
available_models = []
|
||||
diffusers_model_names = []
|
||||
for it in cache_dir.glob("**/*/model_index.json"):
|
||||
model_index_files = glob.glob(os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True)
|
||||
for it in model_index_files:
|
||||
it = Path(it)
|
||||
with open(it, "r", encoding="utf-8") as f:
|
||||
try:
|
||||
data = json.load(f)
|
||||
|
931
iopaint/model/brushnet/brushnet.py
Normal file
931
iopaint/model/brushnet/brushnet.py
Normal file
@ -0,0 +1,931 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
from diffusers.utils import BaseOutput, logging
|
||||
from diffusers.models.attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, \
|
||||
TimestepEmbedding, Timesteps
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.models.unets.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
DownBlock2D, get_down_block, get_up_block,
|
||||
)
|
||||
|
||||
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||
from .unet_2d_blocks import MidBlock2D
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
|
||||
class BrushNetOutput(BaseOutput):
|
||||
"""
|
||||
The output of [`BrushNetModel`].
|
||||
|
||||
Args:
|
||||
up_block_res_samples (`tuple[torch.Tensor]`):
|
||||
A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
|
||||
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
||||
used to condition the original UNet's upsampling activations.
|
||||
down_block_res_samples (`tuple[torch.Tensor]`):
|
||||
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
||||
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
||||
used to condition the original UNet's downsampling activations.
|
||||
mid_down_block_re_sample (`torch.Tensor`):
|
||||
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
|
||||
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
||||
Output can be used to condition the original UNet's middle block activation.
|
||||
"""
|
||||
|
||||
up_block_res_samples: Tuple[torch.Tensor]
|
||||
down_block_res_samples: Tuple[torch.Tensor]
|
||||
mid_block_res_sample: torch.Tensor
|
||||
|
||||
|
||||
class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
A BrushNet model.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to 4):
|
||||
The number of channels in the input sample.
|
||||
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||
Whether to flip the sin to cos in the time embedding.
|
||||
freq_shift (`int`, defaults to 0):
|
||||
The frequency shift to apply to the time embedding.
|
||||
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
||||
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
|
||||
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
||||
The tuple of upsample blocks to use.
|
||||
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||
The tuple of output channels for each block.
|
||||
layers_per_block (`int`, defaults to 2):
|
||||
The number of layers per block.
|
||||
downsample_padding (`int`, defaults to 1):
|
||||
The padding to use for the downsampling convolution.
|
||||
mid_block_scale_factor (`float`, defaults to 1):
|
||||
The scale factor to use for the mid block.
|
||||
act_fn (`str`, defaults to "silu"):
|
||||
The activation function to use.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
||||
in post-processing.
|
||||
norm_eps (`float`, defaults to 1e-5):
|
||||
The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, defaults to 1280):
|
||||
The dimension of the cross attention features.
|
||||
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||
dimension to `cross_attention_dim`.
|
||||
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||
The dimension of the attention heads.
|
||||
use_linear_projection (`bool`, defaults to `False`):
|
||||
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
||||
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||
addition_embed_type (`str`, *optional*, defaults to `None`):
|
||||
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||
num_class_embeds (`int`, *optional*, defaults to 0):
|
||||
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||
class conditioning with `class_embed_type` equal to `None`.
|
||||
upcast_attention (`bool`, defaults to `False`):
|
||||
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
||||
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
||||
`class_embed_type="projection"`.
|
||||
brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
||||
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
||||
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
||||
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||
global_pool_conditions (`bool`, defaults to `False`):
|
||||
TODO(Patrick) - unused parameter.
|
||||
addition_embed_type_num_heads (`int`, defaults to 64):
|
||||
The number of heads to use for the `TextTimeEmbedding` layer.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 5,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"DownBlock2D",
|
||||
"DownBlock2D",
|
||||
"DownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2D",
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D",
|
||||
"UpBlock2D",
|
||||
"UpBlock2D",
|
||||
"UpBlock2D",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
brushnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||
# which is why we correct for the naming here.
|
||||
num_attention_heads = num_attention_heads or attention_head_dim
|
||||
|
||||
# Check inputs
|
||||
if len(down_block_types) != len(up_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
||||
)
|
||||
|
||||
if len(block_out_channels) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||
|
||||
# input
|
||||
conv_in_kernel = 3
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
self.conv_in_condition = nn.Conv2d(
|
||||
in_channels + conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel,
|
||||
padding=conv_in_padding
|
||||
)
|
||||
|
||||
# time
|
||||
time_embed_dim = block_out_channels[0] * 4
|
||||
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||
timestep_input_dim = block_out_channels[0]
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
timestep_input_dim,
|
||||
time_embed_dim,
|
||||
act_fn=act_fn,
|
||||
)
|
||||
|
||||
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
||||
encoder_hid_dim_type = "text_proj"
|
||||
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
||||
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
||||
|
||||
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
||||
)
|
||||
|
||||
if encoder_hid_dim_type == "text_proj":
|
||||
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
||||
elif encoder_hid_dim_type == "text_image_proj":
|
||||
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
||||
self.encoder_hid_proj = TextImageProjection(
|
||||
text_embed_dim=encoder_hid_dim,
|
||||
image_embed_dim=cross_attention_dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
|
||||
elif encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
||||
)
|
||||
else:
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
# class embedding
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
elif class_embed_type == "projection":
|
||||
if projection_class_embeddings_input_dim is None:
|
||||
raise ValueError(
|
||||
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
||||
)
|
||||
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
||||
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
||||
# 2. it projects from an arbitrary input dimension.
|
||||
#
|
||||
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
||||
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
||||
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
||||
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
if addition_embed_type == "text":
|
||||
if encoder_hid_dim is not None:
|
||||
text_time_embedding_from_dim = encoder_hid_dim
|
||||
else:
|
||||
text_time_embedding_from_dim = cross_attention_dim
|
||||
|
||||
self.add_embedding = TextTimeEmbedding(
|
||||
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
||||
)
|
||||
elif addition_embed_type == "text_image":
|
||||
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
||||
self.add_embedding = TextImageTimeEmbedding(
|
||||
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
||||
)
|
||||
elif addition_embed_type == "text_time":
|
||||
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
||||
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||
|
||||
elif addition_embed_type is not None:
|
||||
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.brushnet_down_blocks = nn.ModuleList([])
|
||||
|
||||
if isinstance(only_cross_attention, bool):
|
||||
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
||||
|
||||
if isinstance(attention_head_dim, int):
|
||||
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||
|
||||
if isinstance(num_attention_heads, int):
|
||||
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||
|
||||
# down
|
||||
output_channel = block_out_channels[0]
|
||||
|
||||
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_down_blocks.append(brushnet_block)
|
||||
|
||||
for i, down_block_type in enumerate(down_block_types):
|
||||
input_channel = output_channel
|
||||
output_channel = block_out_channels[i]
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = get_down_block(
|
||||
down_block_type,
|
||||
num_layers=layers_per_block,
|
||||
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_downsample=not is_final_block,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads[i],
|
||||
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||
downsample_padding=downsample_padding,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
for _ in range(layers_per_block):
|
||||
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_down_blocks.append(brushnet_block)
|
||||
|
||||
if not is_final_block:
|
||||
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_down_blocks.append(brushnet_block)
|
||||
|
||||
# mid
|
||||
mid_block_channel = block_out_channels[-1]
|
||||
|
||||
brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_mid_block = brushnet_block
|
||||
|
||||
self.mid_block = MidBlock2D(
|
||||
in_channels=mid_block_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
dropout=0.0,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
resnet_groups=norm_num_groups,
|
||||
use_linear_projection=use_linear_projection,
|
||||
)
|
||||
|
||||
# count how many layers upsample the images
|
||||
self.num_upsamplers = 0
|
||||
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
||||
reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
|
||||
only_cross_attention = list(reversed(only_cross_attention))
|
||||
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
self.brushnet_up_blocks = nn.ModuleList([])
|
||||
|
||||
for i, up_block_type in enumerate(up_block_types):
|
||||
is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||
|
||||
# add upsample block for all BUT final layer
|
||||
if not is_final_block:
|
||||
add_upsample = True
|
||||
self.num_upsamplers += 1
|
||||
else:
|
||||
add_upsample = False
|
||||
|
||||
up_block = get_up_block(
|
||||
up_block_type,
|
||||
num_layers=layers_per_block + 1,
|
||||
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
||||
in_channels=input_channel,
|
||||
out_channels=output_channel,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=time_embed_dim,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
resolution_idx=i,
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=reversed_num_attention_heads[i],
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||
)
|
||||
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
for _ in range(layers_per_block + 1):
|
||||
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_up_blocks.append(brushnet_block)
|
||||
|
||||
if not is_final_block:
|
||||
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_up_blocks.append(brushnet_block)
|
||||
|
||||
@classmethod
|
||||
def from_unet(
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
brushnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
load_weights_from_unet: bool = True,
|
||||
conditioning_channels: int = 5,
|
||||
):
|
||||
r"""
|
||||
Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
|
||||
|
||||
Parameters:
|
||||
unet (`UNet2DConditionModel`):
|
||||
The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
|
||||
where applicable.
|
||||
"""
|
||||
transformer_layers_per_block = (
|
||||
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
||||
)
|
||||
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
||||
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
||||
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
||||
addition_time_embed_dim = (
|
||||
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
||||
)
|
||||
|
||||
brushnet = cls(
|
||||
in_channels=unet.config.in_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
||||
freq_shift=unet.config.freq_shift,
|
||||
down_block_types=['DownBlock2D', 'DownBlock2D', 'DownBlock2D', 'DownBlock2D'],
|
||||
mid_block_type='MidBlock2D',
|
||||
up_block_types=['UpBlock2D', 'UpBlock2D', 'UpBlock2D', 'UpBlock2D'],
|
||||
only_cross_attention=unet.config.only_cross_attention,
|
||||
block_out_channels=unet.config.block_out_channels,
|
||||
layers_per_block=unet.config.layers_per_block,
|
||||
downsample_padding=unet.config.downsample_padding,
|
||||
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
||||
act_fn=unet.config.act_fn,
|
||||
norm_num_groups=unet.config.norm_num_groups,
|
||||
norm_eps=unet.config.norm_eps,
|
||||
cross_attention_dim=unet.config.cross_attention_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
encoder_hid_dim=encoder_hid_dim,
|
||||
encoder_hid_dim_type=encoder_hid_dim_type,
|
||||
attention_head_dim=unet.config.attention_head_dim,
|
||||
num_attention_heads=unet.config.num_attention_heads,
|
||||
use_linear_projection=unet.config.use_linear_projection,
|
||||
class_embed_type=unet.config.class_embed_type,
|
||||
addition_embed_type=addition_embed_type,
|
||||
addition_time_embed_dim=addition_time_embed_dim,
|
||||
num_class_embeds=unet.config.num_class_embeds,
|
||||
upcast_attention=unet.config.upcast_attention,
|
||||
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
||||
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
||||
brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
|
||||
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||
)
|
||||
|
||||
if load_weights_from_unet:
|
||||
conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
|
||||
conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
|
||||
conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
|
||||
brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
|
||||
brushnet.conv_in_condition.bias = unet.conv_in.bias
|
||||
|
||||
brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
||||
brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
||||
|
||||
if brushnet.class_embedding:
|
||||
brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
||||
|
||||
brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
|
||||
brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
|
||||
brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
|
||||
|
||||
return brushnet
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||
indexed by its weight name.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"))
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor)
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
||||
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
||||
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
||||
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
||||
must be a multiple of `slice_size`.
|
||||
"""
|
||||
sliceable_head_dims = []
|
||||
|
||||
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
sliceable_head_dims.append(module.sliceable_head_dim)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_retrieve_sliceable_dims(child)
|
||||
|
||||
# retrieve number of attention layers
|
||||
for module in self.children():
|
||||
fn_recursive_retrieve_sliceable_dims(module)
|
||||
|
||||
num_sliceable_layers = len(sliceable_head_dims)
|
||||
|
||||
if slice_size == "auto":
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
||||
elif slice_size == "max":
|
||||
# make smallest slice possible
|
||||
slice_size = num_sliceable_layers * [1]
|
||||
|
||||
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
||||
|
||||
if len(slice_size) != len(sliceable_head_dims):
|
||||
raise ValueError(
|
||||
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
||||
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
||||
)
|
||||
|
||||
for i in range(len(slice_size)):
|
||||
size = slice_size[i]
|
||||
dim = sliceable_head_dims[i]
|
||||
if size is not None and size > dim:
|
||||
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
||||
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_attention_slice(child, slice_size)
|
||||
|
||||
reversed_slice_size = list(reversed(slice_size))
|
||||
for module in self.children():
|
||||
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
brushnet_cond: torch.FloatTensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
|
||||
"""
|
||||
The [`BrushNetModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor.
|
||||
timestep (`Union[torch.Tensor, float, int]`):
|
||||
The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
The encoder hidden states.
|
||||
brushnet_cond (`torch.FloatTensor`):
|
||||
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||
conditioning_scale (`float`, defaults to `1.0`):
|
||||
The scale factor for BrushNet outputs.
|
||||
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
||||
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
||||
embeddings.
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
added_cond_kwargs (`dict`):
|
||||
Additional conditions for the Stable Diffusion XL UNet.
|
||||
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||
guess_mode (`bool`, defaults to `False`):
|
||||
In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if
|
||||
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
return_dict (`bool`, defaults to `True`):
|
||||
Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.brushnet.BrushNetOutput`] **or** `tuple`:
|
||||
If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
|
||||
returned where the first element is the sample tensor.
|
||||
"""
|
||||
# check channel order
|
||||
channel_order = self.config.brushnet_conditioning_channel_order
|
||||
|
||||
if channel_order == "rgb":
|
||||
# in rgb order by default
|
||||
...
|
||||
elif channel_order == "bgr":
|
||||
brushnet_cond = torch.flip(brushnet_cond, dims=[1])
|
||||
else:
|
||||
raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 1. time
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||
# This would be a good case for the `match` statement (Python 3.10+)
|
||||
is_mps = sample.device.type == "mps"
|
||||
if isinstance(timestep, float):
|
||||
dtype = torch.float32 if is_mps else torch.float64
|
||||
else:
|
||||
dtype = torch.int32 if is_mps else torch.int64
|
||||
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||
elif len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(sample.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps.expand(sample.shape[0])
|
||||
|
||||
t_emb = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||
# there might be better ways to encapsulate this.
|
||||
t_emb = t_emb.to(dtype=sample.dtype)
|
||||
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
aug_emb = None
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
if self.config.addition_embed_type is not None:
|
||||
if self.config.addition_embed_type == "text":
|
||||
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||
|
||||
elif self.config.addition_embed_type == "text_time":
|
||||
if "text_embeds" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
text_embeds = added_cond_kwargs.get("text_embeds")
|
||||
if "time_ids" not in added_cond_kwargs:
|
||||
raise ValueError(
|
||||
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
||||
)
|
||||
time_ids = added_cond_kwargs.get("time_ids")
|
||||
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||
|
||||
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||
add_embeds = add_embeds.to(emb.dtype)
|
||||
aug_emb = self.add_embedding(add_embeds)
|
||||
|
||||
emb = emb + aug_emb if aug_emb is not None else emb
|
||||
|
||||
# 2. pre-process
|
||||
brushnet_cond = torch.concat([sample, brushnet_cond], 1)
|
||||
sample = self.conv_in_condition(brushnet_cond)
|
||||
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. PaintingNet down blocks
|
||||
brushnet_down_block_res_samples = ()
|
||||
for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
|
||||
down_block_res_sample = brushnet_down_block(down_block_res_sample)
|
||||
brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
|
||||
|
||||
# 5. mid
|
||||
if self.mid_block is not None:
|
||||
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
)
|
||||
else:
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# 6. BrushNet mid blocks
|
||||
brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
|
||||
|
||||
# 7. up
|
||||
up_block_res_samples = ()
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
is_final_block = i == len(self.up_blocks) - 1
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
|
||||
# if we have not reached the final block and need to forward the
|
||||
# upsample size, we do it here
|
||||
if not is_final_block:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
|
||||
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
||||
sample, up_res_samples = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
upsample_size=upsample_size,
|
||||
attention_mask=attention_mask,
|
||||
return_res_samples=True
|
||||
)
|
||||
else:
|
||||
sample, up_res_samples = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
upsample_size=upsample_size,
|
||||
return_res_samples=True
|
||||
)
|
||||
|
||||
up_block_res_samples += up_res_samples
|
||||
|
||||
# 8. BrushNet up blocks
|
||||
brushnet_up_block_res_samples = ()
|
||||
for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
|
||||
up_block_res_sample = brushnet_up_block(up_block_res_sample)
|
||||
brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
|
||||
|
||||
# 6. scaling
|
||||
if guess_mode and not self.config.global_pool_conditions:
|
||||
scales = torch.logspace(-1, 0,
|
||||
len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
|
||||
device=sample.device) # 0.1 to 1.0
|
||||
scales = scales * conditioning_scale
|
||||
|
||||
brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples,
|
||||
scales[:len(
|
||||
brushnet_down_block_res_samples)])]
|
||||
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
|
||||
brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples,
|
||||
scales[
|
||||
len(brushnet_down_block_res_samples) + 1:])]
|
||||
else:
|
||||
brushnet_down_block_res_samples = [sample * conditioning_scale for sample in
|
||||
brushnet_down_block_res_samples]
|
||||
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
|
||||
brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
brushnet_down_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
|
||||
]
|
||||
brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||
brushnet_up_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
|
||||
]
|
||||
|
||||
if not return_dict:
|
||||
return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
|
||||
|
||||
return BrushNetOutput(
|
||||
down_block_res_samples=brushnet_down_block_res_samples,
|
||||
mid_block_res_sample=brushnet_mid_block_res_sample,
|
||||
up_block_res_samples=brushnet_up_block_res_samples
|
||||
)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
for p in module.parameters():
|
||||
nn.init.zeros_(p)
|
||||
return module
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
BrushNetModel.from_pretrained("/Users/cwq/data/models/brushnet/brushnet_random_mask", variant='fp16',
|
||||
use_safetensors=True)
|
322
iopaint/model/brushnet/brushnet_unet_forward.py
Normal file
322
iopaint/model/brushnet/brushnet_unet_forward.py
Normal file
@ -0,0 +1,322 @@
|
||||
from typing import Union, Optional, Dict, Any, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
||||
from diffusers.utils import USE_PEFT_BACKEND, unscale_lora_layers, deprecate, scale_lora_layers
|
||||
|
||||
|
||||
def brushnet_unet_forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
||||
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||
mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
|
||||
up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||
r"""
|
||||
The [`UNet2DConditionModel`] forward method.
|
||||
|
||||
Args:
|
||||
sample (`torch.FloatTensor`):
|
||||
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
||||
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||
encoder_hidden_states (`torch.FloatTensor`):
|
||||
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
||||
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
|
||||
through the `self.time_embedding` layer to obtain the timestep embeddings.
|
||||
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||
negative values to the attention scores corresponding to "discard" tokens.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
added_cond_kwargs: (`dict`, *optional*):
|
||||
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
||||
are passed along to the UNet blocks.
|
||||
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
|
||||
A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
||||
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
||||
A tensor that if specified is added to the residual of the middle unet block.
|
||||
encoder_attention_mask (`torch.Tensor`):
|
||||
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
||||
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
||||
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||
tuple.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
||||
added_cond_kwargs: (`dict`, *optional*):
|
||||
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
||||
are passed along to the UNet blocks.
|
||||
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
||||
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
|
||||
example from ControlNet side model(s)
|
||||
mid_block_additional_residual (`torch.Tensor`, *optional*):
|
||||
additional residual to be added to UNet mid block output, for example from ControlNet side model
|
||||
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
||||
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
||||
|
||||
Returns:
|
||||
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
||||
a `tuple` is returned where the first element is the sample tensor.
|
||||
"""
|
||||
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
||||
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||
# on the fly if necessary.
|
||||
default_overall_up_factor = 2 ** self.num_upsamplers
|
||||
|
||||
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||
forward_upsample_size = False
|
||||
upsample_size = None
|
||||
|
||||
for dim in sample.shape[-2:]:
|
||||
if dim % default_overall_up_factor != 0:
|
||||
# Forward upsample size to force interpolation output size.
|
||||
forward_upsample_size = True
|
||||
break
|
||||
|
||||
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
||||
# expects mask of shape:
|
||||
# [batch, key_tokens]
|
||||
# adds singleton query_tokens dimension:
|
||||
# [batch, 1, key_tokens]
|
||||
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||
if attention_mask is not None:
|
||||
# assume that mask is expressed as:
|
||||
# (1 = keep, 0 = discard)
|
||||
# convert mask into a bias that can be added to attention scores:
|
||||
# (keep = +0, discard = -10000.0)
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None:
|
||||
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
||||
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
|
||||
# 1. time
|
||||
t_emb = self.get_time_embed(sample=sample, timestep=timestep)
|
||||
emb = self.time_embedding(t_emb, timestep_cond)
|
||||
aug_emb = None
|
||||
|
||||
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
|
||||
if class_emb is not None:
|
||||
if self.config.class_embeddings_concat:
|
||||
emb = torch.cat([emb, class_emb], dim=-1)
|
||||
else:
|
||||
emb = emb + class_emb
|
||||
|
||||
aug_emb = self.get_aug_embed(
|
||||
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
||||
)
|
||||
if self.config.addition_embed_type == "image_hint":
|
||||
aug_emb, hint = aug_emb
|
||||
sample = torch.cat([sample, hint], dim=1)
|
||||
|
||||
emb = emb + aug_emb if aug_emb is not None else emb
|
||||
|
||||
if self.time_embed_act is not None:
|
||||
emb = self.time_embed_act(emb)
|
||||
|
||||
encoder_hidden_states = self.process_encoder_hidden_states(
|
||||
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
||||
)
|
||||
|
||||
# 2. pre-process
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
# 2.5 GLIGEN position net
|
||||
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
|
||||
cross_attention_kwargs = cross_attention_kwargs.copy()
|
||||
gligen_args = cross_attention_kwargs.pop("gligen")
|
||||
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
|
||||
|
||||
# 3. down
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
if USE_PEFT_BACKEND:
|
||||
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||
scale_lora_layers(self, lora_scale)
|
||||
|
||||
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
||||
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
|
||||
is_adapter = down_intrablock_additional_residuals is not None
|
||||
# maintain backward compatibility for legacy usage, where
|
||||
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
|
||||
# but can only use one or the other
|
||||
is_brushnet = down_block_add_samples is not None and mid_block_add_sample is not None and up_block_add_samples is not None
|
||||
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
|
||||
deprecate(
|
||||
"T2I should not use down_block_additional_residuals",
|
||||
"1.3.0",
|
||||
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
|
||||
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
|
||||
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
|
||||
standard_warn=False,
|
||||
)
|
||||
down_intrablock_additional_residuals = down_block_additional_residuals
|
||||
is_adapter = True
|
||||
|
||||
down_block_res_samples = (sample,)
|
||||
|
||||
if is_brushnet:
|
||||
sample = sample + down_block_add_samples.pop(0)
|
||||
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
# For t2i-adapter CrossAttnDownBlock2D
|
||||
additional_residuals = {}
|
||||
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
||||
additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
|
||||
|
||||
if is_brushnet and len(down_block_add_samples) > 0:
|
||||
additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
|
||||
for _ in range(
|
||||
len(downsample_block.resnets) + (downsample_block.downsamplers != None))]
|
||||
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
**additional_residuals,
|
||||
)
|
||||
else:
|
||||
additional_residuals = {}
|
||||
if is_brushnet and len(down_block_add_samples) > 0:
|
||||
additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
|
||||
for _ in range(
|
||||
len(downsample_block.resnets) + (downsample_block.downsamplers != None))]
|
||||
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale,
|
||||
**additional_residuals)
|
||||
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
||||
sample += down_intrablock_additional_residuals.pop(0)
|
||||
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
if is_controlnet:
|
||||
new_down_block_res_samples = ()
|
||||
|
||||
for down_block_res_sample, down_block_additional_residual in zip(
|
||||
down_block_res_samples, down_block_additional_residuals
|
||||
):
|
||||
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
||||
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
||||
|
||||
down_block_res_samples = new_down_block_res_samples
|
||||
|
||||
# 4. mid
|
||||
if self.mid_block is not None:
|
||||
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
)
|
||||
else:
|
||||
sample = self.mid_block(sample, emb)
|
||||
|
||||
# To support T2I-Adapter-XL
|
||||
if (
|
||||
is_adapter
|
||||
and len(down_intrablock_additional_residuals) > 0
|
||||
and sample.shape == down_intrablock_additional_residuals[0].shape
|
||||
):
|
||||
sample += down_intrablock_additional_residuals.pop(0)
|
||||
|
||||
if is_controlnet:
|
||||
sample = sample + mid_block_additional_residual
|
||||
|
||||
if is_brushnet:
|
||||
sample = sample + mid_block_add_sample
|
||||
|
||||
# 5. up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
is_final_block = i == len(self.up_blocks) - 1
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
|
||||
# if we have not reached the final block and need to forward the
|
||||
# upsample size, we do it here
|
||||
if not is_final_block and forward_upsample_size:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
|
||||
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
||||
additional_residuals = {}
|
||||
if is_brushnet and len(up_block_add_samples) > 0:
|
||||
additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
|
||||
for _ in range(
|
||||
len(upsample_block.resnets) + (upsample_block.upsamplers != None))]
|
||||
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
upsample_size=upsample_size,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
**additional_residuals,
|
||||
)
|
||||
else:
|
||||
additional_residuals = {}
|
||||
if is_brushnet and len(up_block_add_samples) > 0:
|
||||
additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
|
||||
for _ in range(
|
||||
len(upsample_block.resnets) + (upsample_block.upsamplers != None))]
|
||||
|
||||
sample = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
upsample_size=upsample_size,
|
||||
scale=lora_scale,
|
||||
**additional_residuals,
|
||||
)
|
||||
|
||||
# 6. post-process
|
||||
if self.conv_norm_out:
|
||||
sample = self.conv_norm_out(sample)
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample)
|
||||
|
||||
if USE_PEFT_BACKEND:
|
||||
# remove `lora_scale` from each PEFT layer
|
||||
unscale_lora_layers(self, lora_scale)
|
||||
|
||||
if not return_dict:
|
||||
return (sample,)
|
||||
|
||||
return UNet2DConditionOutput(sample=sample)
|
157
iopaint/model/brushnet/brushnet_wrapper.py
Normal file
157
iopaint/model/brushnet/brushnet_wrapper.py
Normal file
@ -0,0 +1,157 @@
|
||||
import PIL.Image
|
||||
import cv2
|
||||
import torch
|
||||
from loguru import logger
|
||||
import numpy as np
|
||||
|
||||
from ..base import DiffusionInpaintModel
|
||||
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||
from ..original_sd_configs import get_config_files
|
||||
from ..utils import (
|
||||
handle_from_pretrained_exceptions,
|
||||
get_torch_dtype,
|
||||
enable_low_mem,
|
||||
is_local_files_only,
|
||||
)
|
||||
from .brushnet import BrushNetModel
|
||||
from .brushnet_unet_forward import brushnet_unet_forward
|
||||
from .unet_2d_blocks import CrossAttnDownBlock2D_forward, DownBlock2D_forward, CrossAttnUpBlock2D_forward, \
|
||||
UpBlock2D_forward
|
||||
from ...schema import InpaintRequest, ModelType
|
||||
|
||||
|
||||
class BrushNetWrapper(DiffusionInpaintModel):
|
||||
pad_mod = 8
|
||||
min_size = 512
|
||||
|
||||
def init_model(self, device: torch.device, **kwargs):
|
||||
from .pipeline_brushnet import StableDiffusionBrushNetPipeline
|
||||
self.model_info = kwargs["model_info"]
|
||||
self.brushnet_method = kwargs["brushnet_method"]
|
||||
|
||||
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||
self.torch_dtype = torch_dtype
|
||||
|
||||
model_kwargs = {
|
||||
**kwargs.get("pipe_components", {}),
|
||||
"local_files_only": is_local_files_only(**kwargs),
|
||||
}
|
||||
self.local_files_only = model_kwargs["local_files_only"]
|
||||
|
||||
disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
|
||||
"cpu_offload", False
|
||||
)
|
||||
if disable_nsfw_checker:
|
||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||
model_kwargs.update(
|
||||
dict(
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info(f"Loading BrushNet model from {self.brushnet_method}")
|
||||
brushnet = BrushNetModel.from_pretrained(self.brushnet_method, torch_dtype=torch_dtype)
|
||||
|
||||
if self.model_info.is_single_file_diffusers:
|
||||
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
||||
model_kwargs["num_in_channels"] = 4
|
||||
else:
|
||||
model_kwargs["num_in_channels"] = 9
|
||||
|
||||
self.model = StableDiffusionBrushNetPipeline.from_single_file(
|
||||
self.model_id_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
load_safety_checker=not disable_nsfw_checker,
|
||||
config_files=get_config_files(),
|
||||
brushnet=brushnet,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
self.model = handle_from_pretrained_exceptions(
|
||||
StableDiffusionBrushNetPipeline.from_pretrained,
|
||||
pretrained_model_name_or_path=self.model_id_or_path,
|
||||
variant="fp16",
|
||||
torch_dtype=torch_dtype,
|
||||
brushnet=brushnet,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
enable_low_mem(self.model, kwargs.get("low_mem", False))
|
||||
|
||||
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||
logger.info("Enable sequential cpu offload")
|
||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||
else:
|
||||
self.model = self.model.to(device)
|
||||
if kwargs["sd_cpu_textencoder"]:
|
||||
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||
self.model.text_encoder = CPUTextEncoderWrapper(
|
||||
self.model.text_encoder, torch_dtype
|
||||
)
|
||||
|
||||
self.callback = kwargs.pop("callback", None)
|
||||
|
||||
# Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
|
||||
self.model.unet.forward = brushnet_unet_forward.__get__(self.model.unet, self.model.unet.__class__)
|
||||
|
||||
for down_block in self.model.brushnet.down_blocks:
|
||||
down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__)
|
||||
for up_block in self.model.brushnet.up_blocks:
|
||||
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)
|
||||
|
||||
# Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
|
||||
for down_block in self.model.unet.down_blocks:
|
||||
if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
|
||||
down_block.forward = CrossAttnDownBlock2D_forward.__get__(down_block, down_block.__class__)
|
||||
else:
|
||||
down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__)
|
||||
|
||||
for up_block in self.model.unet.up_blocks:
|
||||
if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
|
||||
up_block.forward = CrossAttnUpBlock2D_forward.__get__(up_block, up_block.__class__)
|
||||
else:
|
||||
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)
|
||||
|
||||
def switch_brushnet_method(self, new_method: str):
|
||||
self.brushnet_method = new_method
|
||||
brushnet = BrushNetModel.from_pretrained(
|
||||
new_method,
|
||||
resume_download=True,
|
||||
local_files_only=self.local_files_only,
|
||||
torch_dtype=self.torch_dtype,
|
||||
).to(self.model.device)
|
||||
self.model.brushnet = brushnet
|
||||
|
||||
def forward(self, image, mask, config: InpaintRequest):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
mask: [H, W, 1] 255 means area to repaint
|
||||
return: BGR IMAGE
|
||||
"""
|
||||
self.set_scheduler(config)
|
||||
|
||||
img_h, img_w = image.shape[:2]
|
||||
normalized_mask = mask[:, :].astype("float32") / 255.0
|
||||
image = image * (1 - normalized_mask)
|
||||
image = image.astype(np.uint8)
|
||||
output = self.model(
|
||||
image=PIL.Image.fromarray(image),
|
||||
prompt=config.prompt,
|
||||
negative_prompt=config.negative_prompt,
|
||||
mask=PIL.Image.fromarray(mask[:, :, -1], mode="L").convert("RGB"),
|
||||
num_inference_steps=config.sd_steps,
|
||||
# strength=config.sd_strength,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
output_type="np",
|
||||
callback_on_step_end=self.callback,
|
||||
height=img_h,
|
||||
width=img_w,
|
||||
generator=torch.manual_seed(config.sd_seed),
|
||||
brushnet_conditioning_scale=config.brushnet_conditioning_scale,
|
||||
).images[0]
|
||||
|
||||
output = (output * 255).round().astype("uint8")
|
||||
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||
return output
|
1279
iopaint/model/brushnet/pipeline_brushnet.py
Normal file
1279
iopaint/model/brushnet/pipeline_brushnet.py
Normal file
File diff suppressed because it is too large
Load Diff
388
iopaint/model/brushnet/unet_2d_blocks.py
Normal file
388
iopaint/model/brushnet/unet_2d_blocks.py
Normal file
@ -0,0 +1,388 @@
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from diffusers.models.resnet import ResnetBlock2D
|
||||
from diffusers.utils import is_torch_version
|
||||
from diffusers.utils.torch_utils import apply_freeu
|
||||
from torch import nn
|
||||
|
||||
|
||||
class MidBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
output_scale_factor: float = 1.0,
|
||||
use_linear_projection: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.has_cross_attention = False
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
]
|
||||
|
||||
for i in range(num_layers):
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
lora_scale = 1.0
|
||||
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
|
||||
for resnet in self.resnets[1:]:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def DownBlock2D_forward(
|
||||
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0,
|
||||
down_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
output_states = ()
|
||||
|
||||
for resnet in self.resnets:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
|
||||
if down_block_add_samples is not None:
|
||||
hidden_states = hidden_states + down_block_add_samples.pop(0)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states, scale=scale)
|
||||
|
||||
if down_block_add_samples is not None:
|
||||
hidden_states = hidden_states + down_block_add_samples.pop(0) # todo: add before or after
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
def CrossAttnDownBlock2D_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
additional_residuals: Optional[torch.FloatTensor] = None,
|
||||
down_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||
output_states = ()
|
||||
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
|
||||
blocks = list(zip(self.resnets, self.attentions))
|
||||
|
||||
for i, (resnet, attn) in enumerate(blocks):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# apply additional residuals to the output of the last pair of resnet and attention blocks
|
||||
if i == len(blocks) - 1 and additional_residuals is not None:
|
||||
hidden_states = hidden_states + additional_residuals
|
||||
|
||||
if down_block_add_samples is not None:
|
||||
hidden_states = hidden_states + down_block_add_samples.pop(0)
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states, scale=lora_scale)
|
||||
|
||||
if down_block_add_samples is not None:
|
||||
hidden_states = hidden_states + down_block_add_samples.pop(0) # todo: add before or after
|
||||
|
||||
output_states = output_states + (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
def CrossAttnUpBlock2D_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||
return_res_samples: Optional[bool] = False,
|
||||
up_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
and getattr(self, "b1", None)
|
||||
and getattr(self, "b2", None)
|
||||
)
|
||||
if return_res_samples:
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
# FreeU: Only operate on the first two stages
|
||||
if is_freeu_enabled:
|
||||
hidden_states, res_hidden_states = apply_freeu(
|
||||
self.resolution_idx,
|
||||
hidden_states,
|
||||
res_hidden_states,
|
||||
s1=self.s1,
|
||||
s2=self.s2,
|
||||
b1=self.b1,
|
||||
b2=self.b2,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet),
|
||||
hidden_states,
|
||||
temb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
attention_mask=attention_mask,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
if return_res_samples:
|
||||
output_states = output_states + (hidden_states,)
|
||||
if up_block_add_samples is not None:
|
||||
hidden_states = hidden_states + up_block_add_samples.pop(0)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
|
||||
if return_res_samples:
|
||||
output_states = output_states + (hidden_states,)
|
||||
if up_block_add_samples is not None:
|
||||
hidden_states = hidden_states + up_block_add_samples.pop(0)
|
||||
|
||||
if return_res_samples:
|
||||
return hidden_states, output_states
|
||||
else:
|
||||
return hidden_states
|
||||
|
||||
|
||||
def UpBlock2D_forward(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||
temb: Optional[torch.FloatTensor] = None,
|
||||
upsample_size: Optional[int] = None,
|
||||
scale: float = 1.0,
|
||||
return_res_samples: Optional[bool] = False,
|
||||
up_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
is_freeu_enabled = (
|
||||
getattr(self, "s1", None)
|
||||
and getattr(self, "s2", None)
|
||||
and getattr(self, "b1", None)
|
||||
and getattr(self, "b2", None)
|
||||
)
|
||||
if return_res_samples:
|
||||
output_states = ()
|
||||
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
|
||||
# FreeU: Only operate on the first two stages
|
||||
if is_freeu_enabled:
|
||||
hidden_states, res_hidden_states = apply_freeu(
|
||||
self.resolution_idx,
|
||||
hidden_states,
|
||||
res_hidden_states,
|
||||
s1=self.s1,
|
||||
s2=self.s2,
|
||||
b1=self.b1,
|
||||
b2=self.b2,
|
||||
)
|
||||
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(resnet), hidden_states, temb
|
||||
)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||
|
||||
if return_res_samples:
|
||||
output_states = output_states + (hidden_states,)
|
||||
if up_block_add_samples is not None:
|
||||
hidden_states = hidden_states + up_block_add_samples.pop(0) # todo: add before or after
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
|
||||
|
||||
if return_res_samples:
|
||||
output_states = output_states + (hidden_states,)
|
||||
if up_block_add_samples is not None:
|
||||
hidden_states = hidden_states + up_block_add_samples.pop(0) # todo: add before or after
|
||||
|
||||
if return_res_samples:
|
||||
return hidden_states, output_states
|
||||
else:
|
||||
return hidden_states
|
@ -7,6 +7,7 @@ import numpy as np
|
||||
from iopaint.download import scan_models
|
||||
from iopaint.helper import switch_mps_device
|
||||
from iopaint.model import models, ControlNet, SD, SDXL
|
||||
from iopaint.model.brushnet.brushnet_wrapper import BrushNetWrapper
|
||||
from iopaint.model.utils import torch_gc, is_local_files_only
|
||||
from iopaint.schema import InpaintRequest, ModelInfo, ModelType
|
||||
|
||||
@ -22,12 +23,16 @@ class ModelManager:
|
||||
self.enable_controlnet = kwargs.get("enable_controlnet", False)
|
||||
controlnet_method = kwargs.get("controlnet_method", None)
|
||||
if (
|
||||
controlnet_method is None
|
||||
and name in self.available_models
|
||||
and self.available_models[name].support_controlnet
|
||||
controlnet_method is None
|
||||
and name in self.available_models
|
||||
and self.available_models[name].support_controlnet
|
||||
):
|
||||
controlnet_method = self.available_models[name].controlnets[0]
|
||||
self.controlnet_method = controlnet_method
|
||||
|
||||
self.enable_brushnet = kwargs.get("enable_brushnet", False)
|
||||
self.brushnet_method = kwargs.get("brushnet_method", None)
|
||||
|
||||
self.model = self.init_model(name, device, **kwargs)
|
||||
|
||||
@property
|
||||
@ -47,24 +52,30 @@ class ModelManager:
|
||||
"model_info": model_info,
|
||||
"enable_controlnet": self.enable_controlnet,
|
||||
"controlnet_method": self.controlnet_method,
|
||||
"enable_brushnet": self.enable_brushnet,
|
||||
"brushnet_method": self.brushnet_method,
|
||||
}
|
||||
|
||||
if model_info.support_controlnet and self.enable_controlnet:
|
||||
return ControlNet(device, **kwargs)
|
||||
elif model_info.name in models:
|
||||
return models[name](device, **kwargs)
|
||||
else:
|
||||
if model_info.model_type in [
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SD,
|
||||
]:
|
||||
return SD(device, **kwargs)
|
||||
|
||||
if model_info.model_type in [
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
]:
|
||||
return SDXL(device, **kwargs)
|
||||
if model_info.support_brushnet and self.enable_brushnet:
|
||||
return BrushNetWrapper(device, **kwargs)
|
||||
|
||||
if model_info.name in models:
|
||||
return models[name](device, **kwargs)
|
||||
|
||||
if model_info.model_type in [
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
ModelType.DIFFUSERS_SD,
|
||||
]:
|
||||
return SD(device, **kwargs)
|
||||
|
||||
if model_info.model_type in [
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
]:
|
||||
return SDXL(device, **kwargs)
|
||||
|
||||
raise NotImplementedError(f"Unsupported model: {name}")
|
||||
|
||||
@ -80,7 +91,10 @@ class ModelManager:
|
||||
Returns:
|
||||
BGR image
|
||||
"""
|
||||
self.switch_controlnet_method(config)
|
||||
if not config.enable_brushnet:
|
||||
self.switch_controlnet_method(config)
|
||||
if not config.enable_controlnet:
|
||||
self.switch_brushnet_method(config)
|
||||
self.enable_disable_freeu(config)
|
||||
self.enable_disable_lcm_lora(config)
|
||||
return self.model(image, mask, config).astype(np.uint8)
|
||||
@ -99,9 +113,9 @@ class ModelManager:
|
||||
self.name = new_name
|
||||
|
||||
if (
|
||||
self.available_models[new_name].support_controlnet
|
||||
and self.controlnet_method
|
||||
not in self.available_models[new_name].controlnets
|
||||
self.available_models[new_name].support_controlnet
|
||||
and self.controlnet_method
|
||||
not in self.available_models[new_name].controlnets
|
||||
):
|
||||
self.controlnet_method = self.available_models[new_name].controlnets[0]
|
||||
try:
|
||||
@ -121,14 +135,54 @@ class ModelManager:
|
||||
)
|
||||
raise e
|
||||
|
||||
def switch_brushnet_method(self, config):
|
||||
if not self.available_models[self.name].support_brushnet:
|
||||
return
|
||||
|
||||
if (
|
||||
self.enable_brushnet
|
||||
and config.brushnet_method
|
||||
and self.brushnet_method != config.brushnet_method
|
||||
):
|
||||
old_brushnet_method = self.brushnet_method
|
||||
self.brushnet_method = config.brushnet_method
|
||||
self.model.switch_brushnet_method(config.brushnet_method)
|
||||
logger.info(
|
||||
f"Switch Brushnet method from {old_brushnet_method} to {config.brushnet_method}"
|
||||
)
|
||||
|
||||
elif self.enable_brushnet != config.enable_brushnet:
|
||||
self.enable_brushnet = config.enable_brushnet
|
||||
self.brushnet_method = config.brushnet_method
|
||||
|
||||
pipe_components = {
|
||||
"vae": self.model.model.vae,
|
||||
"text_encoder": self.model.model.text_encoder,
|
||||
"unet": self.model.model.unet,
|
||||
}
|
||||
if hasattr(self.model.model, "text_encoder_2"):
|
||||
pipe_components["text_encoder_2"] = self.model.model.text_encoder_2
|
||||
|
||||
self.model = self.init_model(
|
||||
self.name,
|
||||
switch_mps_device(self.name, self.device),
|
||||
pipe_components=pipe_components,
|
||||
**self.kwargs,
|
||||
)
|
||||
|
||||
if not config.enable_brushnet:
|
||||
logger.info("BrushNet Disabled")
|
||||
else:
|
||||
logger.info("BrushNet Enabled")
|
||||
|
||||
def switch_controlnet_method(self, config):
|
||||
if not self.available_models[self.name].support_controlnet:
|
||||
return
|
||||
|
||||
if (
|
||||
self.enable_controlnet
|
||||
and config.controlnet_method
|
||||
and self.controlnet_method != config.controlnet_method
|
||||
self.enable_controlnet
|
||||
and config.controlnet_method
|
||||
and self.controlnet_method != config.controlnet_method
|
||||
):
|
||||
old_controlnet_method = self.controlnet_method
|
||||
self.controlnet_method = config.controlnet_method
|
||||
@ -155,7 +209,7 @@ class ModelManager:
|
||||
**self.kwargs,
|
||||
)
|
||||
if not config.enable_controlnet:
|
||||
logger.info(f"Disable controlnet")
|
||||
logger.info("Disable controlnet")
|
||||
else:
|
||||
logger.info(f"Enable controlnet: {config.controlnet_method}")
|
||||
|
||||
|
@ -3,6 +3,8 @@ from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional, Literal, List
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from iopaint.const import (
|
||||
INSTRUCT_PIX2PIX_NAME,
|
||||
KANDINSKY22_NAME,
|
||||
@ -11,9 +13,9 @@ from iopaint.const import (
|
||||
SDXL_CONTROLNET_CHOICES,
|
||||
SD2_CONTROLNET_CHOICES,
|
||||
SD_CONTROLNET_CHOICES,
|
||||
SD_BRUSHNET_CHOICES,
|
||||
)
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, field_validator, computed_field
|
||||
from pydantic import BaseModel, Field, computed_field, model_validator
|
||||
|
||||
|
||||
class ModelType(str, Enum):
|
||||
@ -63,6 +65,13 @@ class ModelInfo(BaseModel):
|
||||
return SD_CONTROLNET_CHOICES
|
||||
return []
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def brushnets(self) -> List[str]:
|
||||
if self.model_type in [ModelType.DIFFUSERS_SD]:
|
||||
return SD_BRUSHNET_CHOICES
|
||||
return []
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def support_strength(self) -> bool:
|
||||
@ -103,6 +112,13 @@ class ModelInfo(BaseModel):
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def support_brushnet(self) -> bool:
|
||||
return self.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
]
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def support_freeu(self) -> bool:
|
||||
@ -369,6 +385,13 @@ class InpaintRequest(BaseModel):
|
||||
"lllyasviel/control_v11p_sd15_canny", description="Controlnet method"
|
||||
)
|
||||
|
||||
# BrushNet
|
||||
enable_brushnet: bool = Field(False, description="Enable brushnet")
|
||||
brushnet_method: str = Field(
|
||||
SD_BRUSHNET_CHOICES[0], description="Brushnet method"
|
||||
)
|
||||
brushnet_conditioning_scale: float = Field(1.0, description="brushnet conditioning scale", ge=0.0, le=1.0)
|
||||
|
||||
# PowerPaint
|
||||
powerpaint_task: PowerPaintTask = Field(
|
||||
PowerPaintTask.text_guided, description="PowerPaint task"
|
||||
@ -380,31 +403,37 @@ class InpaintRequest(BaseModel):
|
||||
le=1.0,
|
||||
)
|
||||
|
||||
@field_validator("sd_seed")
|
||||
@classmethod
|
||||
def sd_seed_validator(cls, v: int) -> int:
|
||||
if v == -1:
|
||||
return random.randint(1, 99999999)
|
||||
return v
|
||||
@model_validator(mode='after')
|
||||
def validate_field(cls, values: 'InpaintRequest'):
|
||||
if values.sd_seed == -1:
|
||||
values.sd_seed = random.randint(1, 99999999)
|
||||
logger.info(f"Generate random seed: {values.sd_seed}")
|
||||
|
||||
@field_validator("controlnet_conditioning_scale")
|
||||
@classmethod
|
||||
def validate_field(cls, v: float, values):
|
||||
use_extender = values.data["use_extender"]
|
||||
enable_controlnet = values.data["enable_controlnet"]
|
||||
if use_extender and enable_controlnet:
|
||||
logger.info(f"Extender is enabled, set controlnet_conditioning_scale=0")
|
||||
return 0
|
||||
return v
|
||||
if values.use_extender and values.enable_controlnet:
|
||||
logger.info("Extender is enabled, set controlnet_conditioning_scale=0")
|
||||
values.controlnet_conditioning_scale = 0
|
||||
|
||||
@field_validator("sd_strength")
|
||||
@classmethod
|
||||
def validate_sd_strength(cls, v: float, values):
|
||||
use_extender = values.data["use_extender"]
|
||||
if use_extender:
|
||||
logger.info(f"Extender is enabled, set sd_strength=1")
|
||||
return 1.0
|
||||
return v
|
||||
if values.use_extender:
|
||||
logger.info("Extender is enabled, set sd_strength=1")
|
||||
values.sd_strength = 1.0
|
||||
|
||||
if values.enable_brushnet:
|
||||
logger.info("BrushNet is enabled, set enable_controlnet=False")
|
||||
if values.enable_controlnet:
|
||||
values.enable_controlnet = False
|
||||
if values.sd_lcm_lora:
|
||||
logger.info("BrushNet is enabled, set sd_lcm_lora=False")
|
||||
values.sd_lcm_lora = False
|
||||
if values.sd_freeu:
|
||||
logger.info("BrushNet is enabled, set sd_freeu=False")
|
||||
values.sd_freeu = False
|
||||
|
||||
if values.enable_controlnet:
|
||||
logger.info("ControlNet is enabled, set enable_brushnet=False")
|
||||
if values.enable_brushnet:
|
||||
values.enable_brushnet = False
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class RunPluginRequest(BaseModel):
|
||||
|
89
iopaint/tests/test_brushnet.py
Normal file
89
iopaint/tests/test_brushnet.py
Normal file
@ -0,0 +1,89 @@
|
||||
import os
|
||||
|
||||
from iopaint.const import SD_BRUSHNET_CHOICES
|
||||
from iopaint.tests.utils import check_device, get_config, assert_equal
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from iopaint.model_manager import ModelManager
|
||||
from iopaint.schema import HDStrategy, SDSampler, FREEUConfig
|
||||
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
save_dir = current_dir / "result"
|
||||
save_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m_karras])
|
||||
def test_runway_brushnet(device, sampler):
|
||||
sd_steps = check_device(device)
|
||||
model = ModelManager(
|
||||
name="runwayml/stable-diffusion-v1-5",
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
)
|
||||
cfg = get_config(
|
||||
strategy=HDStrategy.ORIGINAL,
|
||||
prompt="face of a fox, sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
sd_guidance_scale=7.5,
|
||||
sd_freeu=True,
|
||||
sd_freeu_config=FREEUConfig(),
|
||||
enable_brushnet=True,
|
||||
brushnet_method=SD_BRUSHNET_CHOICES[0]
|
||||
)
|
||||
cfg.sd_sampler = sampler
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"brushnet_runway_1_5_freeu_device_{device}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m_karras])
|
||||
@pytest.mark.parametrize(
|
||||
"name",
|
||||
[
|
||||
"v1-5-pruned-emaonly.safetensors",
|
||||
],
|
||||
)
|
||||
def test_brushnet_local_file_path(device, sampler, name):
|
||||
sd_steps = check_device(device)
|
||||
model = ModelManager(
|
||||
name=name,
|
||||
device=torch.device(device),
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=False,
|
||||
)
|
||||
cfg = get_config(
|
||||
strategy=HDStrategy.ORIGINAL,
|
||||
prompt="face of a fox, sitting on a bench",
|
||||
sd_steps=sd_steps,
|
||||
sd_seed=1234,
|
||||
enable_brushnet=True,
|
||||
brushnet_method=SD_BRUSHNET_CHOICES[1]
|
||||
)
|
||||
cfg.sd_sampler = sampler
|
||||
name = f"device_{device}_{sampler}_{name}"
|
||||
|
||||
is_sdxl = "sd_xl" in name
|
||||
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"brushnet_sd_local_model_{name}.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fx=1.5 if is_sdxl else 1,
|
||||
fy=1.5 if is_sdxl else 1,
|
||||
)
|
@ -109,6 +109,84 @@ const DiffusionOptions = () => {
|
||||
)
|
||||
}
|
||||
|
||||
const renderBrushNetSetting = () => {
|
||||
if (!settings.model.support_brushnet) {
|
||||
return null
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex flex-col gap-4">
|
||||
<div className="flex justify-between items-center pr-2">
|
||||
<LabelTitle
|
||||
text="BrushNet"
|
||||
toolTip="BrushNet: A Plug-and-Play Image Inpainting Model with Decomposed Dual-Branch Diffusion"
|
||||
url="https://github.com/TencentARC/BrushNet"
|
||||
/>
|
||||
<Switch
|
||||
id="brushnet"
|
||||
checked={settings.enableBrushNet}
|
||||
onCheckedChange={(value) => {
|
||||
updateSettings({ enableBrushNet: value })
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
<RowContainer>
|
||||
<Slider
|
||||
className="w-[180px]"
|
||||
defaultValue={[100]}
|
||||
min={1}
|
||||
max={100}
|
||||
step={1}
|
||||
disabled={!settings.enableBrushNet}
|
||||
value={[Math.floor(settings.brushnetConditioningScale * 100)]}
|
||||
onValueChange={(vals) =>
|
||||
updateSettings({ brushnetConditioningScale: vals[0] / 100 })
|
||||
}
|
||||
/>
|
||||
<NumberInput
|
||||
id="controlnet-weight"
|
||||
className="w-[60px] rounded-full"
|
||||
disabled={!settings.enableBrushNet}
|
||||
numberValue={settings.brushnetConditioningScale}
|
||||
allowFloat={false}
|
||||
onNumberValueChange={(val) => {
|
||||
updateSettings({ brushnetConditioningScale: val })
|
||||
}}
|
||||
/>
|
||||
</RowContainer>
|
||||
</div>
|
||||
|
||||
<div className="pr-2">
|
||||
<Select
|
||||
defaultValue={settings.brushnetMethod}
|
||||
value={settings.brushnetMethod}
|
||||
onValueChange={(value) => {
|
||||
updateSettings({ brushnetMethod: value })
|
||||
}}
|
||||
disabled={!settings.enableBrushNet}
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder="Select brushnet model" />
|
||||
</SelectTrigger>
|
||||
<SelectContent align="end">
|
||||
<SelectGroup>
|
||||
{Object.values(settings.model.brushnets).map((method) => (
|
||||
<SelectItem key={method} value={method}>
|
||||
{method}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
</div>
|
||||
<Separator />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
const renderConterNetSetting = () => {
|
||||
if (!settings.model.support_controlnet) {
|
||||
return null
|
||||
@ -881,6 +959,7 @@ const DiffusionOptions = () => {
|
||||
{renderSeed()}
|
||||
{renderNegativePrompt()}
|
||||
<Separator />
|
||||
{renderBrushNetSetting()}
|
||||
{renderConterNetSetting()}
|
||||
{renderLCMLora()}
|
||||
{renderMaskBlur()}
|
||||
|
@ -78,6 +78,9 @@ export default async function inpaint(
|
||||
controlnet_method: settings.controlnetMethod
|
||||
? settings.controlnetMethod
|
||||
: "",
|
||||
enable_brushnet: settings.enableBrushNet,
|
||||
brushnet_method: settings.brushnetMethod ? settings.brushnetMethod : "",
|
||||
brushnet_conditioning_scale: settings.brushnetConditioningScale,
|
||||
powerpaint_task: settings.showExtender
|
||||
? PowerPaintTask.outpainting
|
||||
: settings.powerpaintTask,
|
||||
|
@ -99,6 +99,11 @@ export type Settings = {
|
||||
controlnetConditioningScale: number
|
||||
controlnetMethod: string
|
||||
|
||||
// BrushNet
|
||||
enableBrushNet: boolean
|
||||
brushnetMethod: string
|
||||
brushnetConditioningScale: number
|
||||
|
||||
enableLCMLora: boolean
|
||||
enableFreeu: boolean
|
||||
freeuConfig: FreeuConfig
|
||||
@ -306,15 +311,16 @@ const defaultValues: AppState = {
|
||||
path: "lama",
|
||||
model_type: "inpaint",
|
||||
support_controlnet: false,
|
||||
support_brushnet: false,
|
||||
support_strength: false,
|
||||
support_outpainting: false,
|
||||
controlnets: [],
|
||||
brushnets: [],
|
||||
support_freeu: false,
|
||||
support_lcm_lora: false,
|
||||
is_single_file_diffusers: false,
|
||||
need_prompt: false,
|
||||
},
|
||||
enableControlnet: false,
|
||||
showCropper: false,
|
||||
showExtender: false,
|
||||
extenderDirection: ExtenderDirection.xy,
|
||||
@ -339,8 +345,12 @@ const defaultValues: AppState = {
|
||||
sdMatchHistograms: false,
|
||||
sdScale: 1.0,
|
||||
p2pImageGuidanceScale: 1.5,
|
||||
controlnetConditioningScale: 0.4,
|
||||
enableControlnet: false,
|
||||
controlnetMethod: "lllyasviel/control_v11p_sd15_canny",
|
||||
controlnetConditioningScale: 0.4,
|
||||
enableBrushNet: false,
|
||||
brushnetMethod: "random_mask",
|
||||
brushnetConditioningScale: 1.0,
|
||||
enableLCMLora: false,
|
||||
enableFreeu: false,
|
||||
freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 },
|
||||
@ -1076,7 +1086,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
})),
|
||||
{
|
||||
name: "ZUSTAND_STATE", // name of the item in the storage (must be unique)
|
||||
version: 1,
|
||||
version: 2,
|
||||
partialize: (state) =>
|
||||
Object.fromEntries(
|
||||
Object.entries(state).filter(([key]) =>
|
||||
|
@ -48,7 +48,9 @@ export interface ModelInfo {
|
||||
support_strength: boolean
|
||||
support_outpainting: boolean
|
||||
support_controlnet: boolean
|
||||
support_brushnet: boolean
|
||||
controlnets: string[]
|
||||
brushnets: string[]
|
||||
support_freeu: boolean
|
||||
support_lcm_lora: boolean
|
||||
need_prompt: boolean
|
||||
|
Loading…
Reference in New Issue
Block a user