This commit is contained in:
Qing 2024-04-29 22:20:44 +08:00
parent 017a3d68fd
commit 80ee1b9941
11 changed files with 1548 additions and 5684 deletions

View File

@ -1,6 +1,9 @@
from itertools import chain
import PIL.Image
import cv2
import torch
from iopaint.model.original_sd_configs import get_config_files
from loguru import logger
from transformers import CLIPTextModel, CLIPTokenizer
import numpy as np
@ -14,9 +17,15 @@ from ..utils import (
handle_from_pretrained_exceptions,
)
from .powerpaint_tokenizer import task_to_prompt
from iopaint.schema import InpaintRequest
from iopaint.schema import InpaintRequest, ModelType
from .v2.BrushNet_CA import BrushNetModel
from .v2.unet_2d_condition import UNet2DConditionModel
from .v2.unet_2d_condition import UNet2DConditionModel_forward
from .v2.unet_2d_blocks import (
CrossAttnDownBlock2D_forward,
DownBlock2D_forward,
CrossAttnUpBlock2D_forward,
UpBlock2D_forward,
)
class PowerPaintV2(DiffusionInpaintModel):
@ -50,14 +59,7 @@ class PowerPaintV2(DiffusionInpaintModel):
torch_dtype=torch_dtype,
local_files_only=model_kwargs["local_files_only"],
)
unet = handle_from_pretrained_exceptions(
UNet2DConditionModel.from_pretrained,
pretrained_model_name_or_path=self.model_id_or_path,
subfolder="unet",
variant="fp16",
torch_dtype=torch_dtype,
local_files_only=model_kwargs["local_files_only"],
)
brushnet = BrushNetModel.from_pretrained(
self.hf_model_id,
subfolder="PowerPaint_Brushnet",
@ -65,16 +67,32 @@ class PowerPaintV2(DiffusionInpaintModel):
torch_dtype=torch_dtype,
local_files_only=model_kwargs["local_files_only"],
)
pipe = handle_from_pretrained_exceptions(
StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
pretrained_model_name_or_path=self.model_id_or_path,
torch_dtype=torch_dtype,
unet=unet,
brushnet=brushnet,
text_encoder_brushnet=text_encoder_brushnet,
variant="fp16",
**model_kwargs,
)
if self.model_info.is_single_file_diffusers:
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
model_kwargs["num_in_channels"] = 4
else:
model_kwargs["num_in_channels"] = 9
pipe = StableDiffusionPowerPaintBrushNetPipeline.from_single_file(
self.model_id_or_path,
torch_dtype=torch_dtype,
load_safety_checker=False,
original_config_file=get_config_files()["v1"],
brushnet=brushnet,
text_encoder_brushnet=text_encoder_brushnet,
**model_kwargs,
)
else:
pipe = handle_from_pretrained_exceptions(
StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
pretrained_model_name_or_path=self.model_id_or_path,
torch_dtype=torch_dtype,
brushnet=brushnet,
text_encoder_brushnet=text_encoder_brushnet,
variant="fp16",
**model_kwargs,
)
pipe.tokenizer = PowerPaintTokenizer(
CLIPTokenizer.from_pretrained(self.hf_model_id, subfolder="tokenizer")
)
@ -95,6 +113,34 @@ class PowerPaintV2(DiffusionInpaintModel):
self.callback = kwargs.pop("callback", None)
# Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
self.model.unet.forward = UNet2DConditionModel_forward.__get__(
self.model.unet, self.model.unet.__class__
)
# Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
for down_block in chain(
self.model.unet.down_blocks, self.model.brushnet.down_blocks
):
if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
down_block.forward = CrossAttnDownBlock2D_forward.__get__(
down_block, down_block.__class__
)
else:
down_block.forward = DownBlock2D_forward.__get__(
down_block, down_block.__class__
)
for up_block in chain(self.model.unet.up_blocks, self.model.brushnet.up_blocks):
if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
up_block.forward = CrossAttnUpBlock2D_forward.__get__(
up_block, up_block.__class__
)
else:
up_block.forward = UpBlock2D_forward.__get__(
up_block, up_block.__class__
)
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
@ -129,11 +175,10 @@ class PowerPaintV2(DiffusionInpaintModel):
brushnet_conditioning_scale=1.0,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
callback_on_step_end=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),
callback_steps=1,
).images[0]
output = (output * 255).round().astype("uint8")

View File

@ -2,6 +2,14 @@ from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.unet_2d_blocks import (
get_down_block,
get_mid_block,
get_up_block,
CrossAttnDownBlock2D,
DownBlock2D,
)
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
@ -13,18 +21,14 @@ from diffusers.models.attention_processor import (
AttnAddedKVProcessor,
AttnProcessor,
)
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, \
TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from .unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
get_down_block,
get_mid_block,
get_up_block
from diffusers.models.embeddings import (
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
TimestepEmbedding,
Timesteps,
)
from .unet_2d_condition import UNet2DConditionModel
from diffusers.models.modeling_utils import ModelMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@ -132,47 +136,55 @@ class BrushNetModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
in_channels: int = 4,
conditioning_channels: int = 5,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str, ...] = (
"UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
projection_class_embeddings_input_dim: Optional[int] = None,
brushnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
addition_embed_type_num_heads: int = 64,
self,
in_channels: int = 4,
conditioning_channels: int = 5,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str, ...] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
projection_class_embeddings_input_dim: Optional[int] = None,
brushnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (
16,
32,
96,
256,
),
global_pool_conditions: bool = False,
addition_embed_type_num_heads: int = 64,
):
super().__init__()
@ -195,25 +207,33 @@ class BrushNetModel(ModelMixin, ConfigMixin):
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
if not isinstance(only_cross_attention, bool) and len(
only_cross_attention
) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
down_block_types
):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
transformer_layers_per_block = [transformer_layers_per_block] * len(
down_block_types
)
# input
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in_condition = nn.Conv2d(
in_channels + conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel,
padding=conv_in_padding
in_channels + conditioning_channels,
block_out_channels[0],
kernel_size=conv_in_kernel,
padding=conv_in_padding,
)
# time
@ -229,7 +249,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
logger.info(
"encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
)
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
raise ValueError(
@ -274,7 +296,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
self.class_embedding = TimestepEmbedding(
projection_class_embeddings_input_dim, time_embed_dim
)
else:
self.class_embedding = None
@ -285,21 +309,31 @@ class BrushNetModel(ModelMixin, ConfigMixin):
text_time_embedding_from_dim = cross_attention_dim
self.add_embedding = TextTimeEmbedding(
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
text_time_embedding_from_dim,
time_embed_dim,
num_heads=addition_embed_type_num_heads,
)
elif addition_embed_type == "text_image":
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
text_embed_dim=cross_attention_dim,
image_embed_dim=cross_attention_dim,
time_embed_dim=time_embed_dim,
)
elif addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
self.add_time_proj = Timesteps(
addition_time_embed_dim, flip_sin_to_cos, freq_shift
)
self.add_embedding = TimestepEmbedding(
projection_class_embeddings_input_dim, time_embed_dim
)
elif addition_embed_type is not None:
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
raise ValueError(
f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
)
self.down_blocks = nn.ModuleList([])
self.brushnet_down_blocks = nn.ModuleList([])
@ -338,7 +372,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[i],
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
attention_head_dim=attention_head_dim[i]
if attention_head_dim[i] is not None
else output_channel,
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
@ -348,12 +384,16 @@ class BrushNetModel(ModelMixin, ConfigMixin):
self.down_blocks.append(down_block)
for _ in range(layers_per_block):
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
brushnet_block = nn.Conv2d(
output_channel, output_channel, kernel_size=1
)
brushnet_block = zero_module(brushnet_block)
self.brushnet_down_blocks.append(brushnet_block)
if not is_final_block:
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
brushnet_block = nn.Conv2d(
output_channel, output_channel, kernel_size=1
)
brushnet_block = zero_module(brushnet_block)
self.brushnet_down_blocks.append(brushnet_block)
@ -386,7 +426,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
reversed_transformer_layers_per_block = list(
reversed(transformer_layers_per_block)
)
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
@ -399,7 +441,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
input_channel = reversed_block_out_channels[
min(i + 1, len(block_out_channels) - 1)
]
# add upsample block for all BUT final layer
if not is_final_block:
@ -427,29 +471,40 @@ class BrushNetModel(ModelMixin, ConfigMixin):
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
attention_head_dim=attention_head_dim[i]
if attention_head_dim[i] is not None
else output_channel,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
for _ in range(layers_per_block + 1):
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
brushnet_block = nn.Conv2d(
output_channel, output_channel, kernel_size=1
)
brushnet_block = zero_module(brushnet_block)
self.brushnet_up_blocks.append(brushnet_block)
if not is_final_block:
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
brushnet_block = nn.Conv2d(
output_channel, output_channel, kernel_size=1
)
brushnet_block = zero_module(brushnet_block)
self.brushnet_up_blocks.append(brushnet_block)
@classmethod
def from_unet(
cls,
unet: UNet2DConditionModel,
brushnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
load_weights_from_unet: bool = True,
conditioning_channels: int = 5,
cls,
unet: UNet2DConditionModel,
brushnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (
16,
32,
96,
256,
),
load_weights_from_unet: bool = True,
conditioning_channels: int = 5,
):
r"""
Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
@ -460,13 +515,27 @@ class BrushNetModel(ModelMixin, ConfigMixin):
where applicable.
"""
transformer_layers_per_block = (
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
unet.config.transformer_layers_per_block
if "transformer_layers_per_block" in unet.config
else 1
)
encoder_hid_dim = (
unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
)
encoder_hid_dim_type = (
unet.config.encoder_hid_dim_type
if "encoder_hid_dim_type" in unet.config
else None
)
addition_embed_type = (
unet.config.addition_embed_type
if "addition_embed_type" in unet.config
else None
)
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
addition_time_embed_dim = (
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
unet.config.addition_time_embed_dim
if "addition_time_embed_dim" in unet.config
else None
)
brushnet = cls(
@ -475,14 +544,21 @@ class BrushNetModel(ModelMixin, ConfigMixin):
flip_sin_to_cos=unet.config.flip_sin_to_cos,
freq_shift=unet.config.freq_shift,
# down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
down_block_types=["CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D", ],
down_block_types=[
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
],
# mid_block_type='MidBlock2D',
mid_block_type="UNetMidBlock2DCrossAttn",
# up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
up_block_types=[
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
],
only_cross_attention=unet.config.only_cross_attention,
block_out_channels=unet.config.block_out_channels,
layers_per_block=unet.config.layers_per_block,
@ -510,21 +586,33 @@ class BrushNetModel(ModelMixin, ConfigMixin):
)
if load_weights_from_unet:
conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
conv_in_condition_weight = torch.zeros_like(
brushnet.conv_in_condition.weight
)
conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
brushnet.conv_in_condition.weight = torch.nn.Parameter(
conv_in_condition_weight
)
brushnet.conv_in_condition.bias = unet.conv_in.bias
brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
if brushnet.class_embedding:
brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
brushnet.class_embedding.load_state_dict(
unet.class_embedding.state_dict()
)
brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
brushnet.down_blocks.load_state_dict(
unet.down_blocks.state_dict(), strict=False
)
brushnet.mid_block.load_state_dict(
unet.mid_block.state_dict(), strict=False
)
brushnet.up_blocks.load_state_dict(
unet.up_blocks.state_dict(), strict=False
)
return brushnet.to(unet.dtype)
@ -539,9 +627,15 @@ class BrushNetModel(ModelMixin, ConfigMixin):
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
def fn_recursive_add_processors(
name: str,
module: torch.nn.Module,
processors: Dict[str, AttentionProcessor],
):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
processors[f"{name}.processor"] = module.get_processor(
return_deprecated_lora=True
)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@ -554,7 +648,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
):
r"""
Sets the attention processor to use to compute attention.
@ -593,9 +689,15 @@ class BrushNetModel(ModelMixin, ConfigMixin):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
if all(
proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
for proc in self.attn_processors.values()
):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
elif all(
proc.__class__ in CROSS_ATTENTION_PROCESSORS
for proc in self.attn_processors.values()
):
processor = AttnProcessor()
else:
raise ValueError(
@ -642,7 +744,11 @@ class BrushNetModel(ModelMixin, ConfigMixin):
# make smallest slice possible
slice_size = num_sliceable_layers * [1]
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
slice_size = (
num_sliceable_layers * [slice_size]
if not isinstance(slice_size, list)
else slice_size
)
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
@ -659,7 +765,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
def fn_recursive_set_attention_slice(
module: torch.nn.Module, slice_size: List[int]
):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
@ -675,19 +783,19 @@ class BrushNetModel(ModelMixin, ConfigMixin):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
brushnet_cond: torch.FloatTensor,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guess_mode: bool = False,
return_dict: bool = True,
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
brushnet_cond: torch.FloatTensor,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
"""
The [`BrushNetModel`] forward method.
@ -737,7 +845,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
elif channel_order == "bgr":
brushnet_cond = torch.flip(brushnet_cond, dims=[1])
else:
raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
raise ValueError(
f"unknown `brushnet_conditioning_channel_order`: {channel_order}"
)
# prepare attention_mask
if attention_mask is not None:
@ -773,7 +883,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
raise ValueError(
"class_labels should be provided when num_class_embeds > 0"
)
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
@ -812,7 +924,10 @@ class BrushNetModel(ModelMixin, ConfigMixin):
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
if (
hasattr(downsample_block, "has_cross_attention")
and downsample_block.has_cross_attention
):
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
@ -827,13 +942,20 @@ class BrushNetModel(ModelMixin, ConfigMixin):
# 4. PaintingNet down blocks
brushnet_down_block_res_samples = ()
for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
for down_block_res_sample, brushnet_down_block in zip(
down_block_res_samples, self.brushnet_down_blocks
):
down_block_res_sample = brushnet_down_block(down_block_res_sample)
brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
brushnet_down_block_res_samples = brushnet_down_block_res_samples + (
down_block_res_sample,
)
# 5. mid
if self.mid_block is not None:
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
if (
hasattr(self.mid_block, "has_cross_attention")
and self.mid_block.has_cross_attention
):
sample = self.mid_block(
sample,
emb,
@ -852,15 +974,20 @@ class BrushNetModel(ModelMixin, ConfigMixin):
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[
: -len(upsample_block.resnets)
]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
if (
hasattr(upsample_block, "has_cross_attention")
and upsample_block.has_cross_attention
):
sample, up_res_samples = upsample_block(
hidden_states=sample,
temb=emb,
@ -869,7 +996,7 @@ class BrushNetModel(ModelMixin, ConfigMixin):
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
return_res_samples=True
return_res_samples=True,
)
else:
sample, up_res_samples = upsample_block(
@ -877,53 +1004,87 @@ class BrushNetModel(ModelMixin, ConfigMixin):
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
return_res_samples=True
return_res_samples=True,
)
up_block_res_samples += up_res_samples
# 8. BrushNet up blocks
brushnet_up_block_res_samples = ()
for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
for up_block_res_sample, brushnet_up_block in zip(
up_block_res_samples, self.brushnet_up_blocks
):
up_block_res_sample = brushnet_up_block(up_block_res_sample)
brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
brushnet_up_block_res_samples = brushnet_up_block_res_samples + (
up_block_res_sample,
)
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0,
len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
device=sample.device) # 0.1 to 1.0
scales = torch.logspace(
-1,
0,
len(brushnet_down_block_res_samples)
+ 1
+ len(brushnet_up_block_res_samples),
device=sample.device,
) # 0.1 to 1.0
scales = scales * conditioning_scale
brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples,
scales[:len(
brushnet_down_block_res_samples)])]
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples,
scales[
len(brushnet_down_block_res_samples) + 1:])]
brushnet_down_block_res_samples = [
sample * scale
for sample, scale in zip(
brushnet_down_block_res_samples,
scales[: len(brushnet_down_block_res_samples)],
)
]
brushnet_mid_block_res_sample = (
brushnet_mid_block_res_sample
* scales[len(brushnet_down_block_res_samples)]
)
brushnet_up_block_res_samples = [
sample * scale
for sample, scale in zip(
brushnet_up_block_res_samples,
scales[len(brushnet_down_block_res_samples) + 1 :],
)
]
else:
brushnet_down_block_res_samples = [sample * conditioning_scale for sample in
brushnet_down_block_res_samples]
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
brushnet_down_block_res_samples = [
sample * conditioning_scale
for sample in brushnet_down_block_res_samples
]
brushnet_mid_block_res_sample = (
brushnet_mid_block_res_sample * conditioning_scale
)
brushnet_up_block_res_samples = [
sample * conditioning_scale for sample in brushnet_up_block_res_samples
]
if self.config.global_pool_conditions:
brushnet_down_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
torch.mean(sample, dim=(2, 3), keepdim=True)
for sample in brushnet_down_block_res_samples
]
brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
brushnet_mid_block_res_sample = torch.mean(
brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True
)
brushnet_up_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
torch.mean(sample, dim=(2, 3), keepdim=True)
for sample in brushnet_up_block_res_samples
]
if not return_dict:
return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
return (
brushnet_down_block_res_samples,
brushnet_mid_block_res_sample,
brushnet_up_block_res_samples,
)
return BrushNetOutput(
down_block_res_samples=brushnet_down_block_res_samples,
mid_block_res_sample=brushnet_mid_block_res_sample,
up_block_res_samples=brushnet_up_block_res_samples
up_block_res_samples=brushnet_up_block_res_samples,
)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -122,9 +122,13 @@ class ModelInfo(BaseModel):
@computed_field
@property
def support_powerpaint_v2(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
]
return (
self.model_type
in [
ModelType.DIFFUSERS_SD,
]
and self.name != POWERPAINT_NAME
)
class Choices(str, Enum):
@ -215,7 +219,6 @@ class SDSampler(str, Enum):
lcm = "LCM"
class PowerPaintTask(Choices):
text_guided = "text-guided"
context_aware = "context-aware"

View File

@ -59,6 +59,9 @@ const DiffusionOptions = () => {
updateExtenderDirection,
adjustMask,
clearMask,
updateEnablePowerPaintV2,
updateEnableBrushNet,
updateEnableControlnet,
] = useStore((state) => [
state.serverConfig.samplers,
state.settings,
@ -71,6 +74,9 @@ const DiffusionOptions = () => {
state.updateExtenderDirection,
state.adjustMask,
state.clearMask,
state.updateEnablePowerPaintV2,
state.updateEnableBrushNet,
state.updateEnableControlnet,
])
const [exampleImage, isExampleImageLoaded] = useImage(paintByExampleFile)
const negativePromptRef = useRef(null)
@ -114,12 +120,8 @@ const DiffusionOptions = () => {
return null
}
let disable = settings.enableControlnet
let toolTip =
"BrushNet is a plug-and-play image inpainting model with decomposed dual-branch diffusion. It can be used to inpaint images by conditioning on a mask."
if (disable) {
toolTip = "ControlNet is enabled, BrushNet is disabled."
}
"BrushNet is a plug-and-play image inpainting model works on any SD1.5 base models."
return (
<div className="flex flex-col gap-4">
@ -129,20 +131,19 @@ const DiffusionOptions = () => {
text="BrushNet"
url="https://github.com/TencentARC/BrushNet"
toolTip={toolTip}
disabled={disable}
/>
<Switch
id="brushnet"
checked={settings.enableBrushNet}
onCheckedChange={(value) => {
updateSettings({ enableBrushNet: value })
updateEnableBrushNet(value)
}}
disabled={disable}
/>
</RowContainer>
<RowContainer>
{/* <RowContainer>
<Slider
defaultValue={[100]}
className="w-[180px]"
min={1}
max={100}
step={1}
@ -155,14 +156,13 @@ const DiffusionOptions = () => {
<NumberInput
id="brushnet-weight"
className="w-[60px] rounded-full"
disabled={!settings.enableBrushNet || disable}
numberValue={settings.brushnetConditioningScale}
allowFloat={false}
onNumberValueChange={(val) => {
updateSettings({ brushnetConditioningScale: val })
}}
/>
</RowContainer>
</RowContainer> */}
<RowContainer>
<Select
@ -198,12 +198,8 @@ const DiffusionOptions = () => {
return null
}
let disable = settings.enableBrushNet
let toolTip =
"Using an additional conditioning image to control how an image is generated"
if (disable) {
toolTip = "BrushNet is enabled, ControlNet is disabled."
}
return (
<div className="flex flex-col gap-4">
@ -213,15 +209,13 @@ const DiffusionOptions = () => {
text="ControlNet"
url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inpaint#controlnet"
toolTip={toolTip}
disabled={disable}
/>
<Switch
id="controlnet"
checked={settings.enableControlnet}
onCheckedChange={(value) => {
updateSettings({ enableControlnet: value })
updateEnableControlnet(value)
}}
disabled={disable}
/>
</RowContainer>
@ -233,7 +227,7 @@ const DiffusionOptions = () => {
min={1}
max={100}
step={1}
disabled={!settings.enableControlnet || disable}
disabled={!settings.enableControlnet}
value={[Math.floor(settings.controlnetConditioningScale * 100)]}
onValueChange={(vals) =>
updateSettings({ controlnetConditioningScale: vals[0] / 100 })
@ -242,7 +236,7 @@ const DiffusionOptions = () => {
<NumberInput
id="controlnet-weight"
className="w-[60px] rounded-full"
disabled={!settings.enableControlnet || disable}
disabled={!settings.enableControlnet}
numberValue={settings.controlnetConditioningScale}
allowFloat={false}
onNumberValueChange={(val) => {
@ -286,12 +280,8 @@ const DiffusionOptions = () => {
return null
}
let disable = settings.enableBrushNet
let toolTip =
"Enable quality image generation in typically 2-4 steps. Suggest disabling guidance_scale by setting it to 0. You can also try values between 1.0 and 2.0. When LCM Lora is enabled, LCMSampler will be used automatically."
if (disable) {
toolTip = "BrushNet is enabled, LCM Lora is disabled."
}
"Enable quality image generation in typically 2-8 steps. Suggest disabling guidance_scale by setting it to 0. You can also try values between 1.0 and 2.0. When LCM Lora is enabled, LCMSampler will be used automatically."
return (
<>
@ -300,7 +290,6 @@ const DiffusionOptions = () => {
text="LCM LoRA"
url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inference_with_lcm_lora"
toolTip={toolTip}
disabled={disable}
/>
<Switch
id="lcm-lora"
@ -308,7 +297,6 @@ const DiffusionOptions = () => {
onCheckedChange={(value) => {
updateSettings({ enableLCMLora: value })
}}
disabled={disable}
/>
</RowContainer>
<Separator />
@ -561,10 +549,6 @@ const DiffusionOptions = () => {
}
const renderPowerPaintTaskType = () => {
if (settings.model.name !== POWERPAINT) {
return null
}
return (
<RowContainer>
<LabelTitle
@ -579,7 +563,7 @@ const DiffusionOptions = () => {
}}
disabled={settings.showExtender}
>
<SelectTrigger className="w-[140px]">
<SelectTrigger className="w-[130px]">
<SelectValue placeholder="Select task" />
</SelectTrigger>
<SelectContent align="end">
@ -587,6 +571,7 @@ const DiffusionOptions = () => {
{[
PowerPaintTask.text_guided,
PowerPaintTask.object_remove,
PowerPaintTask.context_aware,
PowerPaintTask.shape_guided,
].map((task) => (
<SelectItem key={task} value={task}>
@ -600,6 +585,44 @@ const DiffusionOptions = () => {
)
}
const renderPowerPaintV1 = () => {
if (settings.model.name !== POWERPAINT) {
return null
}
return (
<>
{renderPowerPaintTaskType()}
<Separator />
</>
)
}
const renderPowerPaintV2 = () => {
if (settings.model.support_powerpaint_v2 === false) {
return null
}
return (
<>
<RowContainer>
<LabelTitle
text="PowerPaint V2"
toolTip="PowerPaint is a plug-and-play image inpainting model works on any SD1.5 base models."
/>
<Switch
id="powerpaint-v2"
checked={settings.enablePowerPaintV2}
onCheckedChange={(value) => {
updateEnablePowerPaintV2(value)
}}
/>
</RowContainer>
{renderPowerPaintTaskType()}
<Separator />
</>
)
}
const renderSteps = () => {
return (
<RowContainer>
@ -868,7 +891,7 @@ const DiffusionOptions = () => {
{renderMaskBlur()}
{renderMaskAdjuster()}
{renderMatchHistograms()}
{renderPowerPaintTaskType()}
{renderPowerPaintV1()}
{renderSteps()}
{renderGuidanceScale()}
{renderP2PImageGuidanceScale()}
@ -878,6 +901,7 @@ const DiffusionOptions = () => {
{renderNegativePrompt()}
<Separator />
{renderBrushNetSetting()}
{renderPowerPaintV2()}
{renderConterNetSetting()}
{renderLCMLora()}
{renderPaintByExample()}

View File

@ -22,7 +22,7 @@ const SelectTrigger = React.forwardRef<
<SelectPrimitive.Trigger
ref={ref}
className={cn(
"flex h-9 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent px-3 py-2 text-sm shadow-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-1 focus:ring-ring disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1",
"flex h-9 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent pl-2 pr-1 py-2 text-sm shadow-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-1 focus:ring-ring disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1",
className
)}
tabIndex={-1}

View File

@ -79,6 +79,7 @@ export default async function inpaint(
enable_brushnet: settings.enableBrushNet,
brushnet_method: settings.brushnetMethod ? settings.brushnetMethod : "",
brushnet_conditioning_scale: settings.brushnetConditioningScale,
enable_powerpaint_v2: settings.enablePowerPaintV2,
powerpaint_task: settings.showExtender
? PowerPaintTask.outpainting
: settings.powerpaintTask,

View File

@ -106,6 +106,7 @@ export type Settings = {
enableLCMLora: boolean
// PowerPaint
enablePowerPaintV2: boolean
powerpaintTask: PowerPaintTask
// AdjustMask
@ -194,6 +195,12 @@ type AppAction = {
setServerConfig: (newValue: ServerConfig) => void
setSeed: (newValue: number) => void
updateSettings: (newSettings: Partial<Settings>) => void
// 互斥
updateEnablePowerPaintV2: (newValue: boolean) => void
updateEnableBrushNet: (newValue: boolean) => void
updateEnableControlnet: (newValue: boolean) => void
setModel: (newModel: ModelInfo) => void
updateFileManagerState: (newState: Partial<FileManagerState>) => void
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
@ -311,6 +318,7 @@ const defaultValues: AppState = {
support_brushnet: false,
support_strength: false,
support_outpainting: false,
support_powerpaint_v2: false,
controlnets: [],
brushnets: [],
support_lcm_lora: false,
@ -425,6 +433,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
if (
get().settings.model.support_outpainting &&
settings.showExtender &&
extenderState.x === 0 &&
extenderState.y === 0 &&
extenderState.height === imageHeight &&
extenderState.width === imageWidth
) {
@ -798,6 +808,38 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
})
},
updateEnablePowerPaintV2: (newValue: boolean) => {
get().updateSettings({ enablePowerPaintV2: newValue })
if (newValue) {
get().updateSettings({
enableBrushNet: false,
enableControlnet: false,
enableLCMLora: false,
})
}
},
updateEnableBrushNet: (newValue: boolean) => {
get().updateSettings({ enableBrushNet: newValue })
if (newValue) {
get().updateSettings({
enablePowerPaintV2: false,
enableControlnet: false,
enableLCMLora: false,
})
}
},
updateEnableControlnet(newValue) {
get().updateSettings({ enableControlnet: newValue })
if (newValue) {
get().updateSettings({
enablePowerPaintV2: false,
enableBrushNet: false,
})
}
},
setModel: (newModel: ModelInfo) => {
set((state) => {
state.settings.model = newModel

View File

@ -49,6 +49,7 @@ export interface ModelInfo {
support_outpainting: boolean
support_controlnet: boolean
support_brushnet: boolean
support_powerpaint_v2: boolean
controlnets: string[]
brushnets: string[]
support_lcm_lora: boolean
@ -123,6 +124,7 @@ export enum ExtenderDirection {
export enum PowerPaintTask {
text_guided = "text-guided",
shape_guided = "shape-guided",
context_aware = "context-aware",
object_remove = "object-remove",
outpainting = "outpainting",
}