Qing 2024-04-29 22:20:44 +08:00
parent 017a3d68fd
commit 80ee1b9941
11 changed files with 1548 additions and 5684 deletions

View File

@@ -1,6 +1,9 @@
+from itertools import chain
 import PIL.Image
 import cv2
 import torch
+from iopaint.model.original_sd_configs import get_config_files
 from loguru import logger
 from transformers import CLIPTextModel, CLIPTokenizer
 import numpy as np
@@ -14,9 +17,15 @@ from ..utils import (
     handle_from_pretrained_exceptions,
 )
 from .powerpaint_tokenizer import task_to_prompt
-from iopaint.schema import InpaintRequest
+from iopaint.schema import InpaintRequest, ModelType
 from .v2.BrushNet_CA import BrushNetModel
-from .v2.unet_2d_condition import UNet2DConditionModel
+from .v2.unet_2d_condition import UNet2DConditionModel_forward
+from .v2.unet_2d_blocks import (
+    CrossAttnDownBlock2D_forward,
+    DownBlock2D_forward,
+    CrossAttnUpBlock2D_forward,
+    UpBlock2D_forward,
+)


 class PowerPaintV2(DiffusionInpaintModel):
@@ -50,14 +59,7 @@ class PowerPaintV2(DiffusionInpaintModel):
             torch_dtype=torch_dtype,
             local_files_only=model_kwargs["local_files_only"],
         )
-        unet = handle_from_pretrained_exceptions(
-            UNet2DConditionModel.from_pretrained,
-            pretrained_model_name_or_path=self.model_id_or_path,
-            subfolder="unet",
-            variant="fp16",
-            torch_dtype=torch_dtype,
-            local_files_only=model_kwargs["local_files_only"],
-        )
         brushnet = BrushNetModel.from_pretrained(
             self.hf_model_id,
             subfolder="PowerPaint_Brushnet",
@@ -65,16 +67,32 @@ class PowerPaintV2(DiffusionInpaintModel):
             torch_dtype=torch_dtype,
             local_files_only=model_kwargs["local_files_only"],
         )
-        pipe = handle_from_pretrained_exceptions(
-            StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
-            pretrained_model_name_or_path=self.model_id_or_path,
-            torch_dtype=torch_dtype,
-            unet=unet,
-            brushnet=brushnet,
-            text_encoder_brushnet=text_encoder_brushnet,
-            variant="fp16",
-            **model_kwargs,
-        )
+        if self.model_info.is_single_file_diffusers:
+            if self.model_info.model_type == ModelType.DIFFUSERS_SD:
+                model_kwargs["num_in_channels"] = 4
+            else:
+                model_kwargs["num_in_channels"] = 9
+
+            pipe = StableDiffusionPowerPaintBrushNetPipeline.from_single_file(
+                self.model_id_or_path,
+                torch_dtype=torch_dtype,
+                load_safety_checker=False,
+                original_config_file=get_config_files()["v1"],
+                brushnet=brushnet,
+                text_encoder_brushnet=text_encoder_brushnet,
+                **model_kwargs,
+            )
+        else:
+            pipe = handle_from_pretrained_exceptions(
+                StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
+                pretrained_model_name_or_path=self.model_id_or_path,
+                torch_dtype=torch_dtype,
+                brushnet=brushnet,
+                text_encoder_brushnet=text_encoder_brushnet,
+                variant="fp16",
+                **model_kwargs,
+            )
         pipe.tokenizer = PowerPaintTokenizer(
             CLIPTokenizer.from_pretrained(self.hf_model_id, subfolder="tokenizer")
         )
@@ -95,6 +113,34 @@ class PowerPaintV2(DiffusionInpaintModel):
         self.callback = kwargs.pop("callback", None)

+        # Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
+        self.model.unet.forward = UNet2DConditionModel_forward.__get__(
+            self.model.unet, self.model.unet.__class__
+        )
+
+        # Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
+        for down_block in chain(
+            self.model.unet.down_blocks, self.model.brushnet.down_blocks
+        ):
+            if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
+                down_block.forward = CrossAttnDownBlock2D_forward.__get__(
+                    down_block, down_block.__class__
+                )
+            else:
+                down_block.forward = DownBlock2D_forward.__get__(
+                    down_block, down_block.__class__
+                )
+
+        for up_block in chain(self.model.unet.up_blocks, self.model.brushnet.up_blocks):
+            if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
+                up_block.forward = CrossAttnUpBlock2D_forward.__get__(
+                    up_block, up_block.__class__
+                )
+            else:
+                up_block.forward = UpBlock2D_forward.__get__(
+                    up_block, up_block.__class__
+                )
+
     def forward(self, image, mask, config: InpaintRequest):
         """Input image and output image have same size
         image: [H, W, C] RGB
@@ -129,11 +175,10 @@ class PowerPaintV2(DiffusionInpaintModel):
             brushnet_conditioning_scale=1.0,
             guidance_scale=config.sd_guidance_scale,
             output_type="np",
-            callback=self.callback,
+            callback_on_step_end=self.callback,
             height=img_h,
             width=img_w,
             generator=torch.manual_seed(config.sd_seed),
-            callback_steps=1,
         ).images[0]
         output = (output * 255).round().astype("uint8")
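The monkey patching above relies on Python's descriptor protocol: binding a plain function to an existing instance with function.__get__(instance, cls) produces a bound method, so the replacement forward receives that instance as self without subclassing or re-instantiating the UNet. A minimal, self-contained sketch of the pattern (the class and function names below are illustrative, not part of IOPaint):

class Block:
    def forward(self, x):
        return x


def patched_forward(self, x):
    # `self` is the Block instance the function was bound to via __get__
    return x * 2


block = Block()
# Same binding trick used for UNet2DConditionModel_forward and the block forwards above
block.forward = patched_forward.__get__(block, Block)
assert block.forward(3) == 6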

View File

@@ -2,6 +2,14 @@ from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
+from diffusers import UNet2DConditionModel
+from diffusers.models.unet_2d_blocks import (
+    get_down_block,
+    get_mid_block,
+    get_up_block,
+    CrossAttnDownBlock2D,
+    DownBlock2D,
+)
 from torch import nn

 from diffusers.configuration_utils import ConfigMixin, register_to_config
@@ -13,18 +21,14 @@ from diffusers.models.attention_processor import (
     AttnAddedKVProcessor,
     AttnProcessor,
 )
-from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, \
-    TimestepEmbedding, Timesteps
-from diffusers.models.modeling_utils import ModelMixin
-from .unet_2d_blocks import (
-    CrossAttnDownBlock2D,
-    DownBlock2D,
-    get_down_block,
-    get_mid_block,
-    get_up_block
-)
+from diffusers.models.embeddings import (
+    TextImageProjection,
+    TextImageTimeEmbedding,
+    TextTimeEmbedding,
+    TimestepEmbedding,
+    Timesteps,
+)
+from diffusers.models.modeling_utils import ModelMixin
+from .unet_2d_condition import UNet2DConditionModel

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -132,47 +136,55 @@ class BrushNetModel(ModelMixin, ConfigMixin):
     @register_to_config
     def __init__(
         self,
         in_channels: int = 4,
         conditioning_channels: int = 5,
         flip_sin_to_cos: bool = True,
         freq_shift: int = 0,
         down_block_types: Tuple[str, ...] = (
             "CrossAttnDownBlock2D",
             "CrossAttnDownBlock2D",
             "CrossAttnDownBlock2D",
             "DownBlock2D",
         ),
         mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
-        up_block_types: Tuple[str, ...] = (
-            "UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"
-        ),
+        up_block_types: Tuple[str, ...] = (
+            "UpBlock2D",
+            "CrossAttnUpBlock2D",
+            "CrossAttnUpBlock2D",
+            "CrossAttnUpBlock2D",
+        ),
         only_cross_attention: Union[bool, Tuple[bool]] = False,
         block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
         layers_per_block: int = 2,
         downsample_padding: int = 1,
         mid_block_scale_factor: float = 1,
         act_fn: str = "silu",
         norm_num_groups: Optional[int] = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
         transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
         encoder_hid_dim: Optional[int] = None,
         encoder_hid_dim_type: Optional[str] = None,
         attention_head_dim: Union[int, Tuple[int, ...]] = 8,
         num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
         use_linear_projection: bool = False,
         class_embed_type: Optional[str] = None,
         addition_embed_type: Optional[str] = None,
         addition_time_embed_dim: Optional[int] = None,
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
         projection_class_embeddings_input_dim: Optional[int] = None,
         brushnet_conditioning_channel_order: str = "rgb",
-        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (
+            16,
+            32,
+            96,
+            256,
+        ),
         global_pool_conditions: bool = False,
         addition_embed_type_num_heads: int = 64,
     ):
         super().__init__()
@@ -195,25 +207,33 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                 f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
             )

-        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+        if not isinstance(only_cross_attention, bool) and len(
+            only_cross_attention
+        ) != len(down_block_types):
             raise ValueError(
                 f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
             )

-        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
+            down_block_types
+        ):
             raise ValueError(
                 f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
             )

         if isinstance(transformer_layers_per_block, int):
-            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+            transformer_layers_per_block = [transformer_layers_per_block] * len(
+                down_block_types
+            )

         # input
         conv_in_kernel = 3
         conv_in_padding = (conv_in_kernel - 1) // 2
         self.conv_in_condition = nn.Conv2d(
-            in_channels + conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel,
-            padding=conv_in_padding
+            in_channels + conditioning_channels,
+            block_out_channels[0],
+            kernel_size=conv_in_kernel,
+            padding=conv_in_padding,
         )

         # time
@@ -229,7 +249,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         if encoder_hid_dim_type is None and encoder_hid_dim is not None:
             encoder_hid_dim_type = "text_proj"
             self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
-            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+            logger.info(
+                "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
+            )

         if encoder_hid_dim is None and encoder_hid_dim_type is not None:
             raise ValueError(
@@ -274,7 +296,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
             # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
             # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
             # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
-            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+            self.class_embedding = TimestepEmbedding(
+                projection_class_embeddings_input_dim, time_embed_dim
+            )
         else:
             self.class_embedding = None
@@ -285,21 +309,31 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                 text_time_embedding_from_dim = cross_attention_dim

             self.add_embedding = TextTimeEmbedding(
-                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+                text_time_embedding_from_dim,
+                time_embed_dim,
+                num_heads=addition_embed_type_num_heads,
             )
         elif addition_embed_type == "text_image":
             # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
             # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
             # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
             self.add_embedding = TextImageTimeEmbedding(
-                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+                text_embed_dim=cross_attention_dim,
+                image_embed_dim=cross_attention_dim,
+                time_embed_dim=time_embed_dim,
             )
         elif addition_embed_type == "text_time":
-            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
-            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+            self.add_time_proj = Timesteps(
+                addition_time_embed_dim, flip_sin_to_cos, freq_shift
+            )
+            self.add_embedding = TimestepEmbedding(
+                projection_class_embeddings_input_dim, time_embed_dim
+            )
         elif addition_embed_type is not None:
-            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+            raise ValueError(
+                f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
+            )

         self.down_blocks = nn.ModuleList([])
         self.brushnet_down_blocks = nn.ModuleList([])
@@ -338,7 +372,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
                 num_attention_heads=num_attention_heads[i],
-                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+                attention_head_dim=attention_head_dim[i]
+                if attention_head_dim[i] is not None
+                else output_channel,
                 downsample_padding=downsample_padding,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
@@ -348,12 +384,16 @@ class BrushNetModel(ModelMixin, ConfigMixin):
             self.down_blocks.append(down_block)

             for _ in range(layers_per_block):
-                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+                brushnet_block = nn.Conv2d(
+                    output_channel, output_channel, kernel_size=1
+                )
                 brushnet_block = zero_module(brushnet_block)
                 self.brushnet_down_blocks.append(brushnet_block)

             if not is_final_block:
-                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+                brushnet_block = nn.Conv2d(
+                    output_channel, output_channel, kernel_size=1
+                )
                 brushnet_block = zero_module(brushnet_block)
                 self.brushnet_down_blocks.append(brushnet_block)
@@ -386,7 +426,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         reversed_num_attention_heads = list(reversed(num_attention_heads))
-        reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
+        reversed_transformer_layers_per_block = list(
+            reversed(transformer_layers_per_block)
+        )
         only_cross_attention = list(reversed(only_cross_attention))

         output_channel = reversed_block_out_channels[0]
@@ -399,7 +441,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
             prev_output_channel = output_channel
             output_channel = reversed_block_out_channels[i]
-            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+            input_channel = reversed_block_out_channels[
+                min(i + 1, len(block_out_channels) - 1)
+            ]

             # add upsample block for all BUT final layer
             if not is_final_block:
@@ -427,29 +471,40 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
-                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+                attention_head_dim=attention_head_dim[i]
+                if attention_head_dim[i] is not None
+                else output_channel,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel

             for _ in range(layers_per_block + 1):
-                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+                brushnet_block = nn.Conv2d(
+                    output_channel, output_channel, kernel_size=1
+                )
                 brushnet_block = zero_module(brushnet_block)
                 self.brushnet_up_blocks.append(brushnet_block)

             if not is_final_block:
-                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+                brushnet_block = nn.Conv2d(
+                    output_channel, output_channel, kernel_size=1
+                )
                 brushnet_block = zero_module(brushnet_block)
                 self.brushnet_up_blocks.append(brushnet_block)

     @classmethod
     def from_unet(
         cls,
         unet: UNet2DConditionModel,
         brushnet_conditioning_channel_order: str = "rgb",
-        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (
+            16,
+            32,
+            96,
+            256,
+        ),
         load_weights_from_unet: bool = True,
         conditioning_channels: int = 5,
     ):
         r"""
         Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
@@ -460,13 +515,27 @@ class BrushNetModel(ModelMixin, ConfigMixin):
             where applicable.
         """
         transformer_layers_per_block = (
-            unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+            unet.config.transformer_layers_per_block
+            if "transformer_layers_per_block" in unet.config
+            else 1
         )
-        encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
-        encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
-        addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
+        encoder_hid_dim = (
+            unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
+        )
+        encoder_hid_dim_type = (
+            unet.config.encoder_hid_dim_type
+            if "encoder_hid_dim_type" in unet.config
+            else None
+        )
+        addition_embed_type = (
+            unet.config.addition_embed_type
+            if "addition_embed_type" in unet.config
+            else None
+        )
         addition_time_embed_dim = (
-            unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
+            unet.config.addition_time_embed_dim
+            if "addition_time_embed_dim" in unet.config
+            else None
         )

         brushnet = cls(
@@ -475,14 +544,21 @@ class BrushNetModel(ModelMixin, ConfigMixin):
             flip_sin_to_cos=unet.config.flip_sin_to_cos,
             freq_shift=unet.config.freq_shift,
             # down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
-            down_block_types=["CrossAttnDownBlock2D",
-                              "CrossAttnDownBlock2D",
-                              "CrossAttnDownBlock2D",
-                              "DownBlock2D", ],
+            down_block_types=[
+                "CrossAttnDownBlock2D",
+                "CrossAttnDownBlock2D",
+                "CrossAttnDownBlock2D",
+                "DownBlock2D",
+            ],
             # mid_block_type='MidBlock2D',
             mid_block_type="UNetMidBlock2DCrossAttn",
             # up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
-            up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
+            up_block_types=[
+                "UpBlock2D",
+                "CrossAttnUpBlock2D",
+                "CrossAttnUpBlock2D",
+                "CrossAttnUpBlock2D",
+            ],
             only_cross_attention=unet.config.only_cross_attention,
             block_out_channels=unet.config.block_out_channels,
             layers_per_block=unet.config.layers_per_block,
@@ -510,21 +586,33 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         )

         if load_weights_from_unet:
-            conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
+            conv_in_condition_weight = torch.zeros_like(
+                brushnet.conv_in_condition.weight
+            )
             conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
             conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
-            brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
+            brushnet.conv_in_condition.weight = torch.nn.Parameter(
+                conv_in_condition_weight
+            )
             brushnet.conv_in_condition.bias = unet.conv_in.bias

             brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
             brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())

             if brushnet.class_embedding:
-                brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
+                brushnet.class_embedding.load_state_dict(
+                    unet.class_embedding.state_dict()
+                )

-            brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
-            brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
-            brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
+            brushnet.down_blocks.load_state_dict(
+                unet.down_blocks.state_dict(), strict=False
+            )
+            brushnet.mid_block.load_state_dict(
+                unet.mid_block.state_dict(), strict=False
+            )
+            brushnet.up_blocks.load_state_dict(
+                unet.up_blocks.state_dict(), strict=False
+            )

         return brushnet.to(unet.dtype)
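A note on the weight copy in the hunk above: conv_in_condition takes in_channels + conditioning_channels = 4 + 5 = 9 input channels, while the UNet's conv_in has 4. The copy reuses the UNet kernel for both the noisy-latent channels (0-3) and the masked-image-latent channels (4-7), leaving the last condition channel zero-initialized. A small shape sketch, assuming SD 1.5 dimensions (illustrative only, not repository code):

import torch

unet_w = torch.randn(320, 4, 3, 3)    # UNet conv_in weight shape in SD 1.5
cond_w = torch.zeros(320, 9, 3, 3)    # 4 latent + 5 condition input channels
cond_w[:, :4] = unet_w                # noisy-latent channels reuse UNet weights
cond_w[:, 4:8] = unet_w               # masked-image-latent channels reuse them too
assert cond_w[:, 8].abs().sum() == 0  # remaining channel (presumably the mask) stays zero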
@@ -539,9 +627,15 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         # set recursively
         processors = {}

-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+        def fn_recursive_add_processors(
+            name: str,
+            module: torch.nn.Module,
+            processors: Dict[str, AttentionProcessor],
+        ):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor(
+                    return_deprecated_lora=True
+                )

             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -554,7 +648,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         return processors

     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
+    ):
         r"""
         Sets the attention processor to use to compute attention.
@@ -593,9 +689,15 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         """
         Disables custom attention processors and sets the default attention implementation.
         """
-        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+        if all(
+            proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
+            for proc in self.attn_processors.values()
+        ):
             processor = AttnAddedKVProcessor()
-        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+        elif all(
+            proc.__class__ in CROSS_ATTENTION_PROCESSORS
+            for proc in self.attn_processors.values()
+        ):
             processor = AttnProcessor()
         else:
             raise ValueError(
@@ -642,7 +744,11 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         # make smallest slice possible
         slice_size = num_sliceable_layers * [1]

-        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+        slice_size = (
+            num_sliceable_layers * [slice_size]
+            if not isinstance(slice_size, list)
+            else slice_size
+        )

         if len(slice_size) != len(sliceable_head_dims):
             raise ValueError(
@@ -659,7 +765,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         # Recursively walk through all the children.
         # Any children which exposes the set_attention_slice method
         # gets the message
-        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+        def fn_recursive_set_attention_slice(
+            module: torch.nn.Module, slice_size: List[int]
+        ):
             if hasattr(module, "set_attention_slice"):
                 module.set_attention_slice(slice_size.pop())
@@ -675,19 +783,19 @@ class BrushNetModel(ModelMixin, ConfigMixin):
             module.gradient_checkpointing = value

     def forward(
         self,
         sample: torch.FloatTensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
         brushnet_cond: torch.FloatTensor,
         conditioning_scale: float = 1.0,
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
         """
         The [`BrushNetModel`] forward method.
@@ -737,7 +845,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         elif channel_order == "bgr":
             brushnet_cond = torch.flip(brushnet_cond, dims=[1])
         else:
-            raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
+            raise ValueError(
+                f"unknown `brushnet_conditioning_channel_order`: {channel_order}"
+            )

         # prepare attention_mask
         if attention_mask is not None:
@@ -773,7 +883,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         if self.class_embedding is not None:
             if class_labels is None:
-                raise ValueError("class_labels should be provided when num_class_embeds > 0")
+                raise ValueError(
+                    "class_labels should be provided when num_class_embeds > 0"
+                )

             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
@@ -812,7 +924,10 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         # 3. down
         down_block_res_samples = (sample,)
         for downsample_block in self.down_blocks:
-            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+            if (
+                hasattr(downsample_block, "has_cross_attention")
+                and downsample_block.has_cross_attention
+            ):
                 sample, res_samples = downsample_block(
                     hidden_states=sample,
                     temb=emb,
@@ -827,13 +942,20 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         # 4. PaintingNet down blocks
         brushnet_down_block_res_samples = ()
-        for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
+        for down_block_res_sample, brushnet_down_block in zip(
+            down_block_res_samples, self.brushnet_down_blocks
+        ):
             down_block_res_sample = brushnet_down_block(down_block_res_sample)
-            brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
+            brushnet_down_block_res_samples = brushnet_down_block_res_samples + (
+                down_block_res_sample,
+            )

         # 5. mid
         if self.mid_block is not None:
-            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+            if (
+                hasattr(self.mid_block, "has_cross_attention")
+                and self.mid_block.has_cross_attention
+            ):
                 sample = self.mid_block(
                     sample,
                     emb,
@@ -852,15 +974,20 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         for i, upsample_block in enumerate(self.up_blocks):
             is_final_block = i == len(self.up_blocks) - 1

-            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
-            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[
+                : -len(upsample_block.resnets)
+            ]

             # if we have not reached the final block and need to forward the
             # upsample size, we do it here
             if not is_final_block:
                 upsample_size = down_block_res_samples[-1].shape[2:]

-            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+            if (
+                hasattr(upsample_block, "has_cross_attention")
+                and upsample_block.has_cross_attention
+            ):
                 sample, up_res_samples = upsample_block(
                     hidden_states=sample,
                     temb=emb,
@@ -869,7 +996,7 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                     cross_attention_kwargs=cross_attention_kwargs,
                     upsample_size=upsample_size,
                     attention_mask=attention_mask,
-                    return_res_samples=True
+                    return_res_samples=True,
                 )
             else:
                 sample, up_res_samples = upsample_block(
@@ -877,53 +1004,87 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                     temb=emb,
                     res_hidden_states_tuple=res_samples,
                     upsample_size=upsample_size,
-                    return_res_samples=True
+                    return_res_samples=True,
                 )

             up_block_res_samples += up_res_samples

         # 8. BrushNet up blocks
         brushnet_up_block_res_samples = ()
-        for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
+        for up_block_res_sample, brushnet_up_block in zip(
+            up_block_res_samples, self.brushnet_up_blocks
+        ):
             up_block_res_sample = brushnet_up_block(up_block_res_sample)
-            brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
+            brushnet_up_block_res_samples = brushnet_up_block_res_samples + (
+                up_block_res_sample,
+            )

         # 6. scaling
         if guess_mode and not self.config.global_pool_conditions:
-            scales = torch.logspace(-1, 0,
-                                    len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
-                                    device=sample.device)  # 0.1 to 1.0
+            scales = torch.logspace(
+                -1,
+                0,
+                len(brushnet_down_block_res_samples)
+                + 1
+                + len(brushnet_up_block_res_samples),
+                device=sample.device,
+            )  # 0.1 to 1.0

             scales = scales * conditioning_scale
-            brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples,
-                                                                                       scales[:len(brushnet_down_block_res_samples)])]
-            brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
-            brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples,
-                                                                                     scales[len(brushnet_down_block_res_samples) + 1:])]
+            brushnet_down_block_res_samples = [
+                sample * scale
+                for sample, scale in zip(
+                    brushnet_down_block_res_samples,
+                    scales[: len(brushnet_down_block_res_samples)],
+                )
+            ]
+            brushnet_mid_block_res_sample = (
+                brushnet_mid_block_res_sample
+                * scales[len(brushnet_down_block_res_samples)]
+            )
+            brushnet_up_block_res_samples = [
+                sample * scale
+                for sample, scale in zip(
+                    brushnet_up_block_res_samples,
+                    scales[len(brushnet_down_block_res_samples) + 1 :],
+                )
+            ]
         else:
-            brushnet_down_block_res_samples = [sample * conditioning_scale for sample in
-                                               brushnet_down_block_res_samples]
-            brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
-            brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
+            brushnet_down_block_res_samples = [
+                sample * conditioning_scale
+                for sample in brushnet_down_block_res_samples
+            ]
+            brushnet_mid_block_res_sample = (
+                brushnet_mid_block_res_sample * conditioning_scale
+            )
+            brushnet_up_block_res_samples = [
+                sample * conditioning_scale for sample in brushnet_up_block_res_samples
+            ]

         if self.config.global_pool_conditions:
             brushnet_down_block_res_samples = [
-                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
+                torch.mean(sample, dim=(2, 3), keepdim=True)
+                for sample in brushnet_down_block_res_samples
             ]
-            brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
+            brushnet_mid_block_res_sample = torch.mean(
+                brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True
+            )
             brushnet_up_block_res_samples = [
-                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
+                torch.mean(sample, dim=(2, 3), keepdim=True)
+                for sample in brushnet_up_block_res_samples
            ]

         if not return_dict:
-            return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
+            return (
+                brushnet_down_block_res_samples,
+                brushnet_mid_block_res_sample,
+                brushnet_up_block_res_samples,
+            )

         return BrushNetOutput(
             down_block_res_samples=brushnet_down_block_res_samples,
             mid_block_res_sample=brushnet_mid_block_res_sample,
-            up_block_res_samples=brushnet_up_block_res_samples
+            up_block_res_samples=brushnet_up_block_res_samples,
         )
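The repeated zero_module(...) calls in this file wrap each 1x1 projection so the BrushNet branch contributes nothing at initialization, in the spirit of ControlNet's zero convolutions; training then gradually opens the branch. zero_module itself is defined elsewhere in the file and is not shown in this diff; a conventional implementation looks roughly like this (a sketch under that assumption, not the repository's exact code):

import torch.nn as nn

def zero_module(module: nn.Module) -> nn.Module:
    # Zero every parameter so the module's output starts at exactly zero.
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

# Example: a zero-initialized 1x1 projection like the brushnet_down/up blocks above
proj = zero_module(nn.Conv2d(320, 320, kernel_size=1))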

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -122,9 +122,13 @@ class ModelInfo(BaseModel):
     @computed_field
     @property
     def support_powerpaint_v2(self) -> bool:
-        return self.model_type in [
-            ModelType.DIFFUSERS_SD,
-        ]
+        return (
+            self.model_type
+            in [
+                ModelType.DIFFUSERS_SD,
+            ]
+            and self.name != POWERPAINT_NAME
+        )


 class Choices(str, Enum):
@@ -215,7 +219,6 @@ class SDSampler(str, Enum):
     lcm = "LCM"


 class PowerPaintTask(Choices):
     text_guided = "text-guided"
+    context_aware = "context-aware"

View File

@@ -59,6 +59,9 @@ const DiffusionOptions = () => {
    updateExtenderDirection,
    adjustMask,
    clearMask,
+    updateEnablePowerPaintV2,
+    updateEnableBrushNet,
+    updateEnableControlnet,
  ] = useStore((state) => [
    state.serverConfig.samplers,
    state.settings,
@@ -71,6 +74,9 @@ const DiffusionOptions = () => {
    state.updateExtenderDirection,
    state.adjustMask,
    state.clearMask,
+    state.updateEnablePowerPaintV2,
+    state.updateEnableBrushNet,
+    state.updateEnableControlnet,
  ])
  const [exampleImage, isExampleImageLoaded] = useImage(paintByExampleFile)
  const negativePromptRef = useRef(null)
@@ -114,12 +120,8 @@ const DiffusionOptions = () => {
      return null
    }

-    let disable = settings.enableControlnet
    let toolTip =
-      "BrushNet is a plug-and-play image inpainting model with decomposed dual-branch diffusion. It can be used to inpaint images by conditioning on a mask."
-    if (disable) {
-      toolTip = "ControlNet is enabled, BrushNet is disabled."
-    }
+      "BrushNet is a plug-and-play image inpainting model works on any SD1.5 base models."

    return (
      <div className="flex flex-col gap-4">
@@ -129,20 +131,19 @@ const DiffusionOptions = () => {
            text="BrushNet"
            url="https://github.com/TencentARC/BrushNet"
            toolTip={toolTip}
-            disabled={disable}
          />
          <Switch
            id="brushnet"
            checked={settings.enableBrushNet}
            onCheckedChange={(value) => {
-              updateSettings({ enableBrushNet: value })
+              updateEnableBrushNet(value)
            }}
-            disabled={disable}
          />
        </RowContainer>
-        <RowContainer>
+        {/* <RowContainer>
          <Slider
            defaultValue={[100]}
+            className="w-[180px]"
            min={1}
            max={100}
            step={1}
@@ -155,14 +156,13 @@ const DiffusionOptions = () => {
          <NumberInput
            id="brushnet-weight"
            className="w-[60px] rounded-full"
-            disabled={!settings.enableBrushNet || disable}
            numberValue={settings.brushnetConditioningScale}
            allowFloat={false}
            onNumberValueChange={(val) => {
              updateSettings({ brushnetConditioningScale: val })
            }}
          />
-        </RowContainer>
+        </RowContainer> */}

        <RowContainer>
          <Select
@@ -198,12 +198,8 @@ const DiffusionOptions = () => {
      return null
    }

-    let disable = settings.enableBrushNet
    let toolTip =
      "Using an additional conditioning image to control how an image is generated"
-    if (disable) {
-      toolTip = "BrushNet is enabled, ControlNet is disabled."
-    }

    return (
      <div className="flex flex-col gap-4">
@@ -213,15 +209,13 @@ const DiffusionOptions = () => {
            text="ControlNet"
            url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inpaint#controlnet"
            toolTip={toolTip}
-            disabled={disable}
          />
          <Switch
            id="controlnet"
            checked={settings.enableControlnet}
            onCheckedChange={(value) => {
-              updateSettings({ enableControlnet: value })
+              updateEnableControlnet(value)
            }}
-            disabled={disable}
          />
        </RowContainer>
@@ -233,7 +227,7 @@ const DiffusionOptions = () => {
            min={1}
            max={100}
            step={1}
-            disabled={!settings.enableControlnet || disable}
+            disabled={!settings.enableControlnet}
            value={[Math.floor(settings.controlnetConditioningScale * 100)]}
            onValueChange={(vals) =>
              updateSettings({ controlnetConditioningScale: vals[0] / 100 })
@@ -242,7 +236,7 @@ const DiffusionOptions = () => {
          <NumberInput
            id="controlnet-weight"
            className="w-[60px] rounded-full"
-            disabled={!settings.enableControlnet || disable}
+            disabled={!settings.enableControlnet}
            numberValue={settings.controlnetConditioningScale}
            allowFloat={false}
            onNumberValueChange={(val) => {
@@ -286,12 +280,8 @@ const DiffusionOptions = () => {
      return null
    }

-    let disable = settings.enableBrushNet
    let toolTip =
-      "Enable quality image generation in typically 2-4 steps. Suggest disabling guidance_scale by setting it to 0. You can also try values between 1.0 and 2.0. When LCM Lora is enabled, LCMSampler will be used automatically."
-    if (disable) {
-      toolTip = "BrushNet is enabled, LCM Lora is disabled."
-    }
+      "Enable quality image generation in typically 2-8 steps. Suggest disabling guidance_scale by setting it to 0. You can also try values between 1.0 and 2.0. When LCM Lora is enabled, LCMSampler will be used automatically."

    return (
      <>
@@ -300,7 +290,6 @@ const DiffusionOptions = () => {
          text="LCM LoRA"
          url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inference_with_lcm_lora"
          toolTip={toolTip}
-          disabled={disable}
        />
        <Switch
          id="lcm-lora"
@@ -308,7 +297,6 @@ const DiffusionOptions = () => {
          onCheckedChange={(value) => {
            updateSettings({ enableLCMLora: value })
          }}
-          disabled={disable}
        />
      </RowContainer>
      <Separator />
@@ -561,10 +549,6 @@ const DiffusionOptions = () => {
  }

  const renderPowerPaintTaskType = () => {
-    if (settings.model.name !== POWERPAINT) {
-      return null
-    }
-
    return (
      <RowContainer>
        <LabelTitle
@@ -579,7 +563,7 @@ const DiffusionOptions = () => {
          }}
          disabled={settings.showExtender}
        >
-          <SelectTrigger className="w-[140px]">
+          <SelectTrigger className="w-[130px]">
            <SelectValue placeholder="Select task" />
          </SelectTrigger>
          <SelectContent align="end">
@@ -587,6 +571,7 @@ const DiffusionOptions = () => {
            {[
              PowerPaintTask.text_guided,
              PowerPaintTask.object_remove,
+              PowerPaintTask.context_aware,
              PowerPaintTask.shape_guided,
            ].map((task) => (
              <SelectItem key={task} value={task}>
@@ -600,6 +585,44 @@ const DiffusionOptions = () => {
    )
  }

+  const renderPowerPaintV1 = () => {
+    if (settings.model.name !== POWERPAINT) {
+      return null
+    }
+    return (
+      <>
+        {renderPowerPaintTaskType()}
+        <Separator />
+      </>
+    )
+  }
+
+  const renderPowerPaintV2 = () => {
+    if (settings.model.support_powerpaint_v2 === false) {
+      return null
+    }
+
+    return (
+      <>
+        <RowContainer>
+          <LabelTitle
+            text="PowerPaint V2"
+            toolTip="PowerPaint is a plug-and-play image inpainting model works on any SD1.5 base models."
+          />
+          <Switch
+            id="powerpaint-v2"
+            checked={settings.enablePowerPaintV2}
+            onCheckedChange={(value) => {
+              updateEnablePowerPaintV2(value)
+            }}
+          />
+        </RowContainer>
+        {renderPowerPaintTaskType()}
+        <Separator />
+      </>
+    )
+  }
+
  const renderSteps = () => {
    return (
      <RowContainer>
@@ -868,7 +891,7 @@ const DiffusionOptions = () => {
      {renderMaskBlur()}
      {renderMaskAdjuster()}
      {renderMatchHistograms()}
-      {renderPowerPaintTaskType()}
+      {renderPowerPaintV1()}
      {renderSteps()}
      {renderGuidanceScale()}
      {renderP2PImageGuidanceScale()}
@@ -878,6 +901,7 @@ const DiffusionOptions = () => {
      {renderNegativePrompt()}
      <Separator />
      {renderBrushNetSetting()}
+      {renderPowerPaintV2()}
      {renderConterNetSetting()}
      {renderLCMLora()}
      {renderPaintByExample()}

View File

@@ -22,7 +22,7 @@ const SelectTrigger = React.forwardRef<
  <SelectPrimitive.Trigger
    ref={ref}
    className={cn(
-      "flex h-9 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent px-3 py-2 text-sm shadow-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-1 focus:ring-ring disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1",
+      "flex h-9 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent pl-2 pr-1 py-2 text-sm shadow-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-1 focus:ring-ring disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1",
      className
    )}
    tabIndex={-1}

View File

@@ -79,6 +79,7 @@ export default async function inpaint(
    enable_brushnet: settings.enableBrushNet,
    brushnet_method: settings.brushnetMethod ? settings.brushnetMethod : "",
    brushnet_conditioning_scale: settings.brushnetConditioningScale,
+    enable_powerpaint_v2: settings.enablePowerPaintV2,
    powerpaint_task: settings.showExtender
      ? PowerPaintTask.outpainting
      : settings.powerpaintTask,

View File

@@ -106,6 +106,7 @@ export type Settings = {
  enableLCMLora: boolean

  // PowerPaint
+  enablePowerPaintV2: boolean
  powerpaintTask: PowerPaintTask

  // AdjustMask
@@ -194,6 +195,12 @@ type AppAction = {
  setServerConfig: (newValue: ServerConfig) => void
  setSeed: (newValue: number) => void
  updateSettings: (newSettings: Partial<Settings>) => void
+
+  // mutually exclusive options
+  updateEnablePowerPaintV2: (newValue: boolean) => void
+  updateEnableBrushNet: (newValue: boolean) => void
+  updateEnableControlnet: (newValue: boolean) => void
+
  setModel: (newModel: ModelInfo) => void
  updateFileManagerState: (newState: Partial<FileManagerState>) => void
  updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
@@ -311,6 +318,7 @@ const defaultValues: AppState = {
    support_brushnet: false,
    support_strength: false,
    support_outpainting: false,
+    support_powerpaint_v2: false,
    controlnets: [],
    brushnets: [],
    support_lcm_lora: false,
@@ -425,6 +433,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
      if (
        get().settings.model.support_outpainting &&
        settings.showExtender &&
+        extenderState.x === 0 &&
+        extenderState.y === 0 &&
        extenderState.height === imageHeight &&
        extenderState.width === imageWidth
      ) {
@@ -798,6 +808,38 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
        })
      },

+      updateEnablePowerPaintV2: (newValue: boolean) => {
+        get().updateSettings({ enablePowerPaintV2: newValue })
+        if (newValue) {
+          get().updateSettings({
+            enableBrushNet: false,
+            enableControlnet: false,
+            enableLCMLora: false,
+          })
+        }
+      },
+
+      updateEnableBrushNet: (newValue: boolean) => {
+        get().updateSettings({ enableBrushNet: newValue })
+        if (newValue) {
+          get().updateSettings({
+            enablePowerPaintV2: false,
+            enableControlnet: false,
+            enableLCMLora: false,
+          })
+        }
+      },
+
+      updateEnableControlnet(newValue) {
+        get().updateSettings({ enableControlnet: newValue })
+        if (newValue) {
+          get().updateSettings({
+            enablePowerPaintV2: false,
+            enableBrushNet: false,
+          })
+        }
+      },
+
      setModel: (newModel: ModelInfo) => {
        set((state) => {
          state.settings.model = newModel

View File

@@ -49,6 +49,7 @@ export interface ModelInfo {
  support_outpainting: boolean
  support_controlnet: boolean
  support_brushnet: boolean
+  support_powerpaint_v2: boolean
  controlnets: string[]
  brushnets: string[]
  support_lcm_lora: boolean
@@ -123,6 +124,7 @@ export enum ExtenderDirection {
 export enum PowerPaintTask {
  text_guided = "text-guided",
  shape_guided = "shape-guided",
+  context_aware = "context-aware",
  object_remove = "object-remove",
  outpainting = "outpainting",
 }