Commit 80ee1b9941 (parent 017a3d68fd)
Author: Qing, 2024-04-29 22:20:44 +08:00
11 changed files with 1548 additions and 5684 deletions

View File

@@ -1,6 +1,9 @@
+from itertools import chain
+
import PIL.Image
import cv2
import torch
+from iopaint.model.original_sd_configs import get_config_files
from loguru import logger
from transformers import CLIPTextModel, CLIPTokenizer
import numpy as np
@@ -14,9 +17,15 @@ from ..utils import (
    handle_from_pretrained_exceptions,
)
from .powerpaint_tokenizer import task_to_prompt
-from iopaint.schema import InpaintRequest
+from iopaint.schema import InpaintRequest, ModelType
from .v2.BrushNet_CA import BrushNetModel
-from .v2.unet_2d_condition import UNet2DConditionModel
+from .v2.unet_2d_condition import UNet2DConditionModel_forward
+from .v2.unet_2d_blocks import (
+    CrossAttnDownBlock2D_forward,
+    DownBlock2D_forward,
+    CrossAttnUpBlock2D_forward,
+    UpBlock2D_forward,
+)


class PowerPaintV2(DiffusionInpaintModel):
@@ -50,14 +59,7 @@ class PowerPaintV2(DiffusionInpaintModel):
            torch_dtype=torch_dtype,
            local_files_only=model_kwargs["local_files_only"],
        )
-        unet = handle_from_pretrained_exceptions(
-            UNet2DConditionModel.from_pretrained,
-            pretrained_model_name_or_path=self.model_id_or_path,
-            subfolder="unet",
-            variant="fp16",
-            torch_dtype=torch_dtype,
-            local_files_only=model_kwargs["local_files_only"],
-        )
+
        brushnet = BrushNetModel.from_pretrained(
            self.hf_model_id,
            subfolder="PowerPaint_Brushnet",
@@ -65,11 +67,27 @@ class PowerPaintV2(DiffusionInpaintModel):
            torch_dtype=torch_dtype,
            local_files_only=model_kwargs["local_files_only"],
        )
-        pipe = handle_from_pretrained_exceptions(
-            StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
-            pretrained_model_name_or_path=self.model_id_or_path,
-            torch_dtype=torch_dtype,
-            unet=unet,
-            brushnet=brushnet,
-            text_encoder_brushnet=text_encoder_brushnet,
-            variant="fp16",
+
+        if self.model_info.is_single_file_diffusers:
+            if self.model_info.model_type == ModelType.DIFFUSERS_SD:
+                model_kwargs["num_in_channels"] = 4
+            else:
+                model_kwargs["num_in_channels"] = 9
+
+            pipe = StableDiffusionPowerPaintBrushNetPipeline.from_single_file(
+                self.model_id_or_path,
+                torch_dtype=torch_dtype,
+                load_safety_checker=False,
+                original_config_file=get_config_files()["v1"],
+                brushnet=brushnet,
+                text_encoder_brushnet=text_encoder_brushnet,
+                **model_kwargs,
+            )
+        else:
+            pipe = handle_from_pretrained_exceptions(
+                StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
+                pretrained_model_name_or_path=self.model_id_or_path,
+                torch_dtype=torch_dtype,
+                brushnet=brushnet,
+                text_encoder_brushnet=text_encoder_brushnet,
+                variant="fp16",
@@ -95,6 +113,34 @@ class PowerPaintV2(DiffusionInpaintModel):
        self.callback = kwargs.pop("callback", None)

+        # Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
+        self.model.unet.forward = UNet2DConditionModel_forward.__get__(
+            self.model.unet, self.model.unet.__class__
+        )
+
+        # Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
+        for down_block in chain(
+            self.model.unet.down_blocks, self.model.brushnet.down_blocks
+        ):
+            if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
+                down_block.forward = CrossAttnDownBlock2D_forward.__get__(
+                    down_block, down_block.__class__
+                )
+            else:
+                down_block.forward = DownBlock2D_forward.__get__(
+                    down_block, down_block.__class__
+                )
+
+        for up_block in chain(self.model.unet.up_blocks, self.model.brushnet.up_blocks):
+            if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
+                up_block.forward = CrossAttnUpBlock2D_forward.__get__(
+                    up_block, up_block.__class__
+                )
+            else:
+                up_block.forward = UpBlock2D_forward.__get__(
+                    up_block, up_block.__class__
+                )
+
    def forward(self, image, mask, config: InpaintRequest):
        """Input image and output image have same size
        image: [H, W, C] RGB
@@ -129,11 +175,10 @@ class PowerPaintV2(DiffusionInpaintModel):
            brushnet_conditioning_scale=1.0,
            guidance_scale=config.sd_guidance_scale,
            output_type="np",
-            callback=self.callback,
+            callback_on_step_end=self.callback,
            height=img_h,
            width=img_w,
            generator=torch.manual_seed(config.sd_seed),
-            callback_steps=1,
        ).images[0]

        output = (output * 255).round().astype("uint8")
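The block added above routes every UNet and BrushNet block through the *_forward functions from .v2 by rebinding them as instance methods. The mechanism is Python's descriptor protocol: function.__get__(obj, cls) returns the function bound to obj, so assigning the result to an attribute overrides the behaviour for that single instance while the class and all other instances keep their original forward. A minimal, self-contained sketch of the same idea (Greeter and loud_hello are made-up names, used only for illustration):

class Greeter:
    def hello(self) -> str:
        return "hello"


def loud_hello(self) -> str:
    # Call the original implementation via the class so the patch does not recurse.
    return Greeter.hello(self).upper() + "!"


g = Greeter()
# __get__ binds the plain function to this instance, shadowing the class method
# for `g` only; other Greeter instances are unaffected.
g.hello = loud_hello.__get__(g, Greeter)
assert g.hello() == "HELLO!"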

View File

@@ -2,6 +2,14 @@ from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
+from diffusers import UNet2DConditionModel
+from diffusers.models.unet_2d_blocks import (
+    get_down_block,
+    get_mid_block,
+    get_up_block,
+    CrossAttnDownBlock2D,
+    DownBlock2D,
+)
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
@@ -13,18 +21,14 @@ from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    AttnProcessor,
)
-from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, \
-    TimestepEmbedding, Timesteps
-from diffusers.models.modeling_utils import ModelMixin
-from .unet_2d_blocks import (
-    CrossAttnDownBlock2D,
-    DownBlock2D,
-    get_down_block,
-    get_mid_block,
-    get_up_block
-)
-from .unet_2d_condition import UNet2DConditionModel
+from diffusers.models.embeddings import (
+    TextImageProjection,
+    TextImageTimeEmbedding,
+    TextTimeEmbedding,
+    TimestepEmbedding,
+    Timesteps,
+)
+from diffusers.models.modeling_utils import ModelMixin

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -145,7 +149,10 @@ class BrushNetModel(ModelMixin, ConfigMixin):
        ),
        mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
        up_block_types: Tuple[str, ...] = (
-            "UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"
+            "UpBlock2D",
+            "CrossAttnUpBlock2D",
+            "CrossAttnUpBlock2D",
+            "CrossAttnUpBlock2D",
        ),
        only_cross_attention: Union[bool, Tuple[bool]] = False,
        block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
@@ -170,7 +177,12 @@ class BrushNetModel(ModelMixin, ConfigMixin):
        resnet_time_scale_shift: str = "default",
        projection_class_embeddings_input_dim: Optional[int] = None,
        brushnet_conditioning_channel_order: str = "rgb",
-        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (
+            16,
+            32,
+            96,
+            256,
+        ),
        global_pool_conditions: bool = False,
        addition_embed_type_num_heads: int = 64,
    ):
@@ -195,25 +207,33 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

-        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+        if not isinstance(only_cross_attention, bool) and len(
+            only_cross_attention
+        ) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
            )

-        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
+            down_block_types
+        ):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if isinstance(transformer_layers_per_block, int):
-            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+            transformer_layers_per_block = [transformer_layers_per_block] * len(
+                down_block_types
+            )

        # input
        conv_in_kernel = 3
        conv_in_padding = (conv_in_kernel - 1) // 2
        self.conv_in_condition = nn.Conv2d(
-            in_channels + conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel,
-            padding=conv_in_padding
+            in_channels + conditioning_channels,
+            block_out_channels[0],
+            kernel_size=conv_in_kernel,
+            padding=conv_in_padding,
        )

        # time
@@ -229,7 +249,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
        if encoder_hid_dim_type is None and encoder_hid_dim is not None:
            encoder_hid_dim_type = "text_proj"
            self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
-            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+            logger.info(
+                "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
+            )

        if encoder_hid_dim is None and encoder_hid_dim_type is not None:
            raise ValueError(
@@ -274,7 +296,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
            # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
            # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
            # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
-            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+            self.class_embedding = TimestepEmbedding(
+                projection_class_embeddings_input_dim, time_embed_dim
+            )
        else:
            self.class_embedding = None
@@ -285,21 +309,31 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                text_time_embedding_from_dim = cross_attention_dim

            self.add_embedding = TextTimeEmbedding(
-                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+                text_time_embedding_from_dim,
+                time_embed_dim,
+                num_heads=addition_embed_type_num_heads,
            )
        elif addition_embed_type == "text_image":
            # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
            # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
            # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
            self.add_embedding = TextImageTimeEmbedding(
-                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+                text_embed_dim=cross_attention_dim,
+                image_embed_dim=cross_attention_dim,
+                time_embed_dim=time_embed_dim,
            )
        elif addition_embed_type == "text_time":
-            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
-            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+            self.add_time_proj = Timesteps(
+                addition_time_embed_dim, flip_sin_to_cos, freq_shift
+            )
+            self.add_embedding = TimestepEmbedding(
+                projection_class_embeddings_input_dim, time_embed_dim
+            )
        elif addition_embed_type is not None:
-            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+            raise ValueError(
+                f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
+            )

        self.down_blocks = nn.ModuleList([])
        self.brushnet_down_blocks = nn.ModuleList([])
@@ -338,7 +372,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                resnet_groups=norm_num_groups,
                cross_attention_dim=cross_attention_dim,
                num_attention_heads=num_attention_heads[i],
-                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+                attention_head_dim=attention_head_dim[i]
+                if attention_head_dim[i] is not None
+                else output_channel,
                downsample_padding=downsample_padding,
                use_linear_projection=use_linear_projection,
                only_cross_attention=only_cross_attention[i],
@@ -348,12 +384,16 @@ class BrushNetModel(ModelMixin, ConfigMixin):
            self.down_blocks.append(down_block)

            for _ in range(layers_per_block):
-                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+                brushnet_block = nn.Conv2d(
+                    output_channel, output_channel, kernel_size=1
+                )
                brushnet_block = zero_module(brushnet_block)
                self.brushnet_down_blocks.append(brushnet_block)

            if not is_final_block:
-                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+                brushnet_block = nn.Conv2d(
+                    output_channel, output_channel, kernel_size=1
+                )
                brushnet_block = zero_module(brushnet_block)
                self.brushnet_down_blocks.append(brushnet_block)
@@ -386,7 +426,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
-        reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
+        reversed_transformer_layers_per_block = list(
+            reversed(transformer_layers_per_block)
+        )
        only_cross_attention = list(reversed(only_cross_attention))

        output_channel = reversed_block_out_channels[0]
@@ -399,7 +441,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
-            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+            input_channel = reversed_block_out_channels[
+                min(i + 1, len(block_out_channels) - 1)
+            ]

            # add upsample block for all BUT final layer
            if not is_final_block:
@@ -427,18 +471,24 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                only_cross_attention=only_cross_attention[i],
                upcast_attention=upcast_attention,
                resnet_time_scale_shift=resnet_time_scale_shift,
-                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+                attention_head_dim=attention_head_dim[i]
+                if attention_head_dim[i] is not None
+                else output_channel,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

            for _ in range(layers_per_block + 1):
-                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+                brushnet_block = nn.Conv2d(
+                    output_channel, output_channel, kernel_size=1
+                )
                brushnet_block = zero_module(brushnet_block)
                self.brushnet_up_blocks.append(brushnet_block)

            if not is_final_block:
-                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+                brushnet_block = nn.Conv2d(
+                    output_channel, output_channel, kernel_size=1
+                )
                brushnet_block = zero_module(brushnet_block)
                self.brushnet_up_blocks.append(brushnet_block)
@@ -447,7 +497,12 @@ class BrushNetModel(ModelMixin, ConfigMixin):
        cls,
        unet: UNet2DConditionModel,
        brushnet_conditioning_channel_order: str = "rgb",
-        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (
+            16,
+            32,
+            96,
+            256,
+        ),
        load_weights_from_unet: bool = True,
        conditioning_channels: int = 5,
    ):
@@ -460,13 +515,27 @@ class BrushNetModel(ModelMixin, ConfigMixin):
            where applicable.
        """
        transformer_layers_per_block = (
-            unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+            unet.config.transformer_layers_per_block
+            if "transformer_layers_per_block" in unet.config
+            else 1
        )
-        encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
-        encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
-        addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
+        encoder_hid_dim = (
+            unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
+        )
+        encoder_hid_dim_type = (
+            unet.config.encoder_hid_dim_type
+            if "encoder_hid_dim_type" in unet.config
+            else None
+        )
+        addition_embed_type = (
+            unet.config.addition_embed_type
+            if "addition_embed_type" in unet.config
+            else None
+        )
        addition_time_embed_dim = (
-            unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
+            unet.config.addition_time_embed_dim
+            if "addition_time_embed_dim" in unet.config
+            else None
        )

        brushnet = cls(
@@ -475,14 +544,21 @@ class BrushNetModel(ModelMixin, ConfigMixin):
            flip_sin_to_cos=unet.config.flip_sin_to_cos,
            freq_shift=unet.config.freq_shift,
            # down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
-            down_block_types=["CrossAttnDownBlock2D",
-                              "CrossAttnDownBlock2D",
-                              "CrossAttnDownBlock2D",
-                              "DownBlock2D", ],
+            down_block_types=[
+                "CrossAttnDownBlock2D",
+                "CrossAttnDownBlock2D",
+                "CrossAttnDownBlock2D",
+                "DownBlock2D",
+            ],
            # mid_block_type='MidBlock2D',
            mid_block_type="UNetMidBlock2DCrossAttn",
            # up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
-            up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
+            up_block_types=[
+                "UpBlock2D",
+                "CrossAttnUpBlock2D",
+                "CrossAttnUpBlock2D",
+                "CrossAttnUpBlock2D",
+            ],
            only_cross_attention=unet.config.only_cross_attention,
            block_out_channels=unet.config.block_out_channels,
            layers_per_block=unet.config.layers_per_block,
@@ -510,21 +586,33 @@ class BrushNetModel(ModelMixin, ConfigMixin):
        )

        if load_weights_from_unet:
-            conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
+            conv_in_condition_weight = torch.zeros_like(
+                brushnet.conv_in_condition.weight
+            )
            conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
            conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
-            brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
+            brushnet.conv_in_condition.weight = torch.nn.Parameter(
+                conv_in_condition_weight
+            )
            brushnet.conv_in_condition.bias = unet.conv_in.bias

            brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
            brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())

            if brushnet.class_embedding:
-                brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
+                brushnet.class_embedding.load_state_dict(
+                    unet.class_embedding.state_dict()
+                )

-            brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
-            brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
-            brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
+            brushnet.down_blocks.load_state_dict(
+                unet.down_blocks.state_dict(), strict=False
+            )
+            brushnet.mid_block.load_state_dict(
+                unet.mid_block.state_dict(), strict=False
+            )
+            brushnet.up_blocks.load_state_dict(
+                unet.up_blocks.state_dict(), strict=False
+            )

        return brushnet.to(unet.dtype)
@@ -539,9 +627,15 @@ class BrushNetModel(ModelMixin, ConfigMixin):
        # set recursively
        processors = {}

-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+        def fn_recursive_add_processors(
+            name: str,
+            module: torch.nn.Module,
+            processors: Dict[str, AttentionProcessor],
+        ):
            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor(
+                    return_deprecated_lora=True
+                )

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -554,7 +648,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
+    ):
        r"""
        Sets the attention processor to use to compute attention.
@@ -593,9 +689,15 @@ class BrushNetModel(ModelMixin, ConfigMixin):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
-        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+        if all(
+            proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
+            for proc in self.attn_processors.values()
+        ):
            processor = AttnAddedKVProcessor()
-        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+        elif all(
+            proc.__class__ in CROSS_ATTENTION_PROCESSORS
+            for proc in self.attn_processors.values()
+        ):
            processor = AttnProcessor()
        else:
            raise ValueError(
@@ -642,7 +744,11 @@ class BrushNetModel(ModelMixin, ConfigMixin):
            # make smallest slice possible
            slice_size = num_sliceable_layers * [1]

-        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+        slice_size = (
+            num_sliceable_layers * [slice_size]
+            if not isinstance(slice_size, list)
+            else slice_size
+        )

        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
@@ -659,7 +765,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
        # Recursively walk through all the children.
        # Any children which exposes the set_attention_slice method
        # gets the message
-        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+        def fn_recursive_set_attention_slice(
+            module: torch.nn.Module, slice_size: List[int]
+        ):
            if hasattr(module, "set_attention_slice"):
                module.set_attention_slice(slice_size.pop())
@@ -737,7 +845,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
        elif channel_order == "bgr":
            brushnet_cond = torch.flip(brushnet_cond, dims=[1])
        else:
-            raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
+            raise ValueError(
+                f"unknown `brushnet_conditioning_channel_order`: {channel_order}"
+            )

        # prepare attention_mask
        if attention_mask is not None:
@@ -773,7 +883,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
        if self.class_embedding is not None:
            if class_labels is None:
-                raise ValueError("class_labels should be provided when num_class_embeds > 0")
+                raise ValueError(
+                    "class_labels should be provided when num_class_embeds > 0"
+                )

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)
@@ -812,7 +924,10 @@ class BrushNetModel(ModelMixin, ConfigMixin):
        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
-            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+            if (
+                hasattr(downsample_block, "has_cross_attention")
+                and downsample_block.has_cross_attention
+            ):
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
@@ -827,13 +942,20 @@ class BrushNetModel(ModelMixin, ConfigMixin):
        # 4. PaintingNet down blocks
        brushnet_down_block_res_samples = ()
-        for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
+        for down_block_res_sample, brushnet_down_block in zip(
+            down_block_res_samples, self.brushnet_down_blocks
+        ):
            down_block_res_sample = brushnet_down_block(down_block_res_sample)
-            brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
+            brushnet_down_block_res_samples = brushnet_down_block_res_samples + (
+                down_block_res_sample,
+            )

        # 5. mid
        if self.mid_block is not None:
-            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+            if (
+                hasattr(self.mid_block, "has_cross_attention")
+                and self.mid_block.has_cross_attention
+            ):
                sample = self.mid_block(
                    sample,
                    emb,
@@ -853,14 +975,19 @@ class BrushNetModel(ModelMixin, ConfigMixin):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
-            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+            down_block_res_samples = down_block_res_samples[
+                : -len(upsample_block.resnets)
+            ]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block:
                upsample_size = down_block_res_samples[-1].shape[2:]

-            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+            if (
+                hasattr(upsample_block, "has_cross_attention")
+                and upsample_block.has_cross_attention
+            ):
                sample, up_res_samples = upsample_block(
                    hidden_states=sample,
                    temb=emb,
@@ -869,7 +996,7 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
-                    return_res_samples=True
+                    return_res_samples=True,
                )
            else:
                sample, up_res_samples = upsample_block(
@@ -877,53 +1004,87 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
-                    return_res_samples=True
+                    return_res_samples=True,
                )

            up_block_res_samples += up_res_samples

        # 8. BrushNet up blocks
        brushnet_up_block_res_samples = ()
-        for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
+        for up_block_res_sample, brushnet_up_block in zip(
+            up_block_res_samples, self.brushnet_up_blocks
+        ):
            up_block_res_sample = brushnet_up_block(up_block_res_sample)
-            brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
+            brushnet_up_block_res_samples = brushnet_up_block_res_samples + (
+                up_block_res_sample,
+            )

        # 6. scaling
        if guess_mode and not self.config.global_pool_conditions:
-            scales = torch.logspace(-1, 0,
-                                    len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
-                                    device=sample.device)  # 0.1 to 1.0
+            scales = torch.logspace(
+                -1,
+                0,
+                len(brushnet_down_block_res_samples)
+                + 1
+                + len(brushnet_up_block_res_samples),
+                device=sample.device,
+            )  # 0.1 to 1.0

            scales = scales * conditioning_scale
-            brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples,
-                                               scales[:len(brushnet_down_block_res_samples)])]
-            brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
-            brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples,
-                                             scales[len(brushnet_down_block_res_samples) + 1:])]
+            brushnet_down_block_res_samples = [
+                sample * scale
+                for sample, scale in zip(
+                    brushnet_down_block_res_samples,
+                    scales[: len(brushnet_down_block_res_samples)],
+                )
+            ]
+            brushnet_mid_block_res_sample = (
+                brushnet_mid_block_res_sample
+                * scales[len(brushnet_down_block_res_samples)]
+            )
+            brushnet_up_block_res_samples = [
+                sample * scale
+                for sample, scale in zip(
+                    brushnet_up_block_res_samples,
+                    scales[len(brushnet_down_block_res_samples) + 1 :],
+                )
+            ]
        else:
-            brushnet_down_block_res_samples = [sample * conditioning_scale for sample in
-                                               brushnet_down_block_res_samples]
-            brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
-            brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
+            brushnet_down_block_res_samples = [
+                sample * conditioning_scale
+                for sample in brushnet_down_block_res_samples
+            ]
+            brushnet_mid_block_res_sample = (
+                brushnet_mid_block_res_sample * conditioning_scale
+            )
+            brushnet_up_block_res_samples = [
+                sample * conditioning_scale for sample in brushnet_up_block_res_samples
+            ]

        if self.config.global_pool_conditions:
            brushnet_down_block_res_samples = [
-                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
+                torch.mean(sample, dim=(2, 3), keepdim=True)
+                for sample in brushnet_down_block_res_samples
            ]
-            brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
+            brushnet_mid_block_res_sample = torch.mean(
+                brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True
+            )
            brushnet_up_block_res_samples = [
-                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
+                torch.mean(sample, dim=(2, 3), keepdim=True)
+                for sample in brushnet_up_block_res_samples
            ]

        if not return_dict:
-            return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
+            return (
+                brushnet_down_block_res_samples,
+                brushnet_mid_block_res_sample,
+                brushnet_up_block_res_samples,
+            )

        return BrushNetOutput(
            down_block_res_samples=brushnet_down_block_res_samples,
            mid_block_res_sample=brushnet_mid_block_res_sample,
-            up_block_res_samples=brushnet_up_block_res_samples
+            up_block_res_samples=brushnet_up_block_res_samples,
        )
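Throughout this file the injected brushnet_block layers are 1x1 convolutions wrapped in zero_module(...), the ControlNet-style trick of zero-initialising the projection so that the extra branch contributes nothing at the start of training and only gradually learns to modify the UNet's residuals. A rough sketch of that idea, assuming a zero_module helper with the behaviour implied by the diff (the channel count 320 is arbitrary):

import torch
from torch import nn


def zero_module(module: nn.Module) -> nn.Module:
    # Zero every parameter so the layer's initial output is exactly zero.
    for p in module.parameters():
        nn.init.zeros_(p)
    return module


proj = zero_module(nn.Conv2d(320, 320, kernel_size=1))
x = torch.randn(1, 320, 8, 8)
# Adding proj(x) to a UNet residual is therefore a no-op at initialisation.
assert torch.all(proj(x) == 0)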

View File

@@ -5,11 +5,21 @@ import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
-from diffusers import StableDiffusionMixin
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from diffusers import StableDiffusionMixin, UNet2DConditionModel
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+)

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
-from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.loaders import (
+    FromSingleFileMixin,
+    IPAdapterMixin,
+    LoraLoaderMixin,
+    TextualInversionLoaderMixin,
+)
from diffusers.models import AutoencoderKL, ImageProjection
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
@@ -21,13 +31,20 @@ from diffusers.utils import (
    scale_lora_layers,
    unscale_lora_layers,
)
-from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from diffusers.utils.torch_utils import (
+    is_compiled_module,
+    is_torch_version,
+    randn_tensor,
+)
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
-from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers.pipelines.stable_diffusion.pipeline_output import (
+    StableDiffusionPipelineOutput,
+)
+from diffusers.pipelines.stable_diffusion.safety_checker import (
+    StableDiffusionSafetyChecker,
+)

from .BrushNet_CA import BrushNetModel
-from .unet_2d_condition import UNet2DConditionModel

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -112,7 +129,9 @@ def retrieve_timesteps(
            second element is the number of inference steps.
    """
    if timesteps is not None:
-        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        accepts_timesteps = "timesteps" in set(
+            inspect.signature(scheduler.set_timesteps).parameters.keys()
+        )
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@@ -220,7 +239,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
            image_encoder=image_encoder,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+        self.image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
+        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
@@ -301,11 +322,13 @@ class StableDiffusionPowerPaintBrushNetPipeline(
            )
            text_input_idsA = text_inputsA.input_ids
            text_input_idsB = text_inputsB.input_ids
-            untruncated_ids = self.tokenizer(promptA, padding="longest", return_tensors="pt").input_ids
+            untruncated_ids = self.tokenizer(
+                promptA, padding="longest", return_tensors="pt"
+            ).input_ids

-            if untruncated_ids.shape[-1] >= text_input_idsA.shape[-1] and not torch.equal(
-                text_input_idsA, untruncated_ids
-            ):
+            if untruncated_ids.shape[-1] >= text_input_idsA.shape[
+                -1
+            ] and not torch.equal(text_input_idsA, untruncated_ids):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
@@ -314,8 +337,10 @@ class StableDiffusionPowerPaintBrushNetPipeline(
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

-            if hasattr(self.text_encoder_brushnet.config,
-                       "use_attention_mask") and self.text_encoder_brushnet.config.use_attention_mask:
+            if (
+                hasattr(self.text_encoder_brushnet.config, "use_attention_mask")
+                and self.text_encoder_brushnet.config.use_attention_mask
+            ):
                attention_mask = text_inputsA.attention_mask.to(device)
            else:
                attention_mask = None
@@ -350,7 +375,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        prompt_embeds = prompt_embeds.view(
+            bs_embed * num_images_per_prompt, seq_len, -1
+        )

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -379,8 +406,12 @@ class StableDiffusionPowerPaintBrushNetPipeline(

            # textual inversion: procecss multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
-                uncond_tokensA = self.maybe_convert_prompt(uncond_tokensA, self.tokenizer)
-                uncond_tokensB = self.maybe_convert_prompt(uncond_tokensB, self.tokenizer)
+                uncond_tokensA = self.maybe_convert_prompt(
+                    uncond_tokensA, self.tokenizer
+                )
+                uncond_tokensB = self.maybe_convert_prompt(
+                    uncond_tokensB, self.tokenizer
+                )

            max_length = prompt_embeds.shape[1]
            uncond_inputA = self.tokenizer(
@@ -398,8 +429,10 @@ class StableDiffusionPowerPaintBrushNetPipeline(
                return_tensors="pt",
            )

-            if hasattr(self.text_encoder_brushnet.config,
-                       "use_attention_mask") and self.text_encoder_brushnet.config.use_attention_mask:
+            if (
+                hasattr(self.text_encoder_brushnet.config, "use_attention_mask")
+                and self.text_encoder_brushnet.config.use_attention_mask
+            ):
                attention_mask = uncond_inputA.attention_mask.to(device)
            else:
                attention_mask = None
@@ -412,7 +445,10 @@ class StableDiffusionPowerPaintBrushNetPipeline(
                uncond_inputB.input_ids.to(device),
                attention_mask=attention_mask,
            )
-            negative_prompt_embeds = negative_prompt_embedsA[0] * (t_nag) + (1 - t_nag) * negative_prompt_embedsB[0]
+            negative_prompt_embeds = (
+                negative_prompt_embedsA[0] * (t_nag)
+                + (1 - t_nag) * negative_prompt_embedsB[0]
+            )

            # negative_prompt_embeds = negative_prompt_embeds[0]
@@ -420,10 +456,16 @@ class StableDiffusionPowerPaintBrushNetPipeline(
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(
+                dtype=prompt_embeds_dtype, device=device
+            )

-            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(
+                1, num_images_per_prompt, 1
+            )
+            negative_prompt_embeds = negative_prompt_embeds.view(
+                batch_size * num_images_per_prompt, seq_len, -1
+            )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
@@ -511,11 +553,13 @@ class StableDiffusionPowerPaintBrushNetPipeline(
            )
            text_input_ids = text_inputs.input_ids
            # print(prompt, text_input_ids)
-            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+            untruncated_ids = self.tokenizer(
+                prompt, padding="longest", return_tensors="pt"
+            ).input_ids

-            if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
-                text_input_ids, untruncated_ids
-            ):
+            if untruncated_ids.shape[-1] >= text_input_ids.shape[
+                -1
+            ] and not torch.equal(text_input_ids, untruncated_ids):
                removed_text = self.tokenizer.batch_decode(
                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
                )
@@ -524,17 +568,24 @@ class StableDiffusionPowerPaintBrushNetPipeline(
                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
                )

-            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+            if (
+                hasattr(self.text_encoder.config, "use_attention_mask")
+                and self.text_encoder.config.use_attention_mask
+            ):
                attention_mask = text_inputs.attention_mask.to(device)
            else:
                attention_mask = None

            if clip_skip is None:
-                prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+                prompt_embeds = self.text_encoder(
+                    text_input_ids.to(device), attention_mask=attention_mask
+                )
                prompt_embeds = prompt_embeds[0]
            else:
                prompt_embeds = self.text_encoder(
-                    text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+                    text_input_ids.to(device),
+                    attention_mask=attention_mask,
+                    output_hidden_states=True,
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
@@ -544,7 +595,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
-                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+                prompt_embeds = self.text_encoder.text_model.final_layer_norm(
+                    prompt_embeds
+                )

        if self.text_encoder is not None:
            prompt_embeds_dtype = self.text_encoder.dtype
@@ -558,7 +611,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
        bs_embed, seq_len, _ = prompt_embeds.shape
        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        prompt_embeds = prompt_embeds.view(
+            bs_embed * num_images_per_prompt, seq_len, -1
+        )

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -595,7 +650,10 @@ class StableDiffusionPowerPaintBrushNetPipeline(
            )
            # print("neg: ", uncond_input.input_ids)

-            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+            if (
+                hasattr(self.text_encoder.config, "use_attention_mask")
+                and self.text_encoder.config.use_attention_mask
+            ):
                attention_mask = uncond_input.attention_mask.to(device)
            else:
                attention_mask = None
@@ -610,10 +668,16 @@ class StableDiffusionPowerPaintBrushNetPipeline(
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(
+                dtype=prompt_embeds_dtype, device=device
+            )

-            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(
+                1, num_images_per_prompt, 1
+            )
+            negative_prompt_embeds = negative_prompt_embeds.view(
+                batch_size * num_images_per_prompt, seq_len, -1
+            )

        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
            # Retrieve the original scale by scaling back the LoRA layers
@@ -624,7 +688,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
        return prompt_embeds

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
-    def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+    def encode_image(
+        self, image, device, num_images_per_prompt, output_hidden_states=None
+    ):
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
@@ -632,14 +698,20 @@ class StableDiffusionPowerPaintBrushNetPipeline(
        image = image.to(device=device, dtype=dtype)
        if output_hidden_states:
-            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
-            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+            image_enc_hidden_states = self.image_encoder(
+                image, output_hidden_states=True
+            ).hidden_states[-2]
+            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(
+                num_images_per_prompt, dim=0
+            )
            uncond_image_enc_hidden_states = self.image_encoder(
                torch.zeros_like(image), output_hidden_states=True
            ).hidden_states[-2]
-            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
-                num_images_per_prompt, dim=0
-            )
+            uncond_image_enc_hidden_states = (
+                uncond_image_enc_hidden_states.repeat_interleave(
+                    num_images_per_prompt, dim=0
+                )
+            )
            return image_enc_hidden_states, uncond_image_enc_hidden_states
        else:
            image_embeds = self.image_encoder(image).image_embeds
@@ -650,13 +722,20 @@ class StableDiffusionPowerPaintBrushNetPipeline(

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
    def prepare_ip_adapter_image_embeds(
-        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+        self,
+        ip_adapter_image,
+        ip_adapter_image_embeds,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
    ):
        if ip_adapter_image_embeds is None:
            if not isinstance(ip_adapter_image, list):
                ip_adapter_image = [ip_adapter_image]

-            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+            if len(ip_adapter_image) != len(
+                self.unet.encoder_hid_proj.image_projection_layers
+            ):
                raise ValueError(
                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
                )
@@ -669,13 +748,17 @@ class StableDiffusionPowerPaintBrushNetPipeline(
                single_image_embeds, single_negative_image_embeds = self.encode_image(
                    single_ip_adapter_image, device, 1, output_hidden_state
                )
-                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_image_embeds = torch.stack(
+                    [single_image_embeds] * num_images_per_prompt, dim=0
+                )
                single_negative_image_embeds = torch.stack(
                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
                )

                if do_classifier_free_guidance:
-                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = torch.cat(
+                        [single_negative_image_embeds, single_image_embeds]
+                    )
                    single_image_embeds = single_image_embeds.to(device)

                image_embeds.append(single_image_embeds)
@@ -684,17 +767,24 @@ class StableDiffusionPowerPaintBrushNetPipeline(
            image_embeds = []
            for single_image_embeds in ip_adapter_image_embeds:
                if do_classifier_free_guidance:
-                    single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+                    single_negative_image_embeds, single_image_embeds = (
+                        single_image_embeds.chunk(2)
+                    )
                    single_image_embeds = single_image_embeds.repeat(
-                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                        num_images_per_prompt,
+                        *(repeat_dims * len(single_image_embeds.shape[1:])),
                    )
                    single_negative_image_embeds = single_negative_image_embeds.repeat(
-                        num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+                        num_images_per_prompt,
+                        *(repeat_dims * len(single_negative_image_embeds.shape[1:])),
                    )
-                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = torch.cat(
+                        [single_negative_image_embeds, single_image_embeds]
+                    )
                else:
                    single_image_embeds = single_image_embeds.repeat(
-                        num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+                        num_images_per_prompt,
+                        *(repeat_dims * len(single_image_embeds.shape[1:])),
                    )
                image_embeds.append(single_image_embeds)
@ -706,10 +796,14 @@ class StableDiffusionPowerPaintBrushNetPipeline(
has_nsfw_concept = None has_nsfw_concept = None
else: else:
if torch.is_tensor(image): if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") feature_extractor_input = self.image_processor.postprocess(
image, output_type="pil"
)
else: else:
feature_extractor_input = self.image_processor.numpy_to_pil(image) feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) safety_checker_input = self.feature_extractor(
feature_extractor_input, return_tensors="pt"
).to(device)
image, has_nsfw_concept = self.safety_checker( image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype) images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
) )
@ -734,13 +828,17 @@ class StableDiffusionPowerPaintBrushNetPipeline(
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1] # and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
extra_step_kwargs = {} extra_step_kwargs = {}
if accepts_eta: if accepts_eta:
extra_step_kwargs["eta"] = eta extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator # check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) accepts_generator = "generator" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_generator: if accepts_generator:
extra_step_kwargs["generator"] = generator extra_step_kwargs["generator"] = generator
return extra_step_kwargs return extra_step_kwargs
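The two membership checks above are plain `inspect`-based introspection: `eta` and `generator` are only forwarded to the scheduler if its `step()` signature actually declares them. A minimal, self-contained sketch of the same pattern (the `step` function below is a stand-in, not a real diffusers scheduler):

import inspect

def step(sample, t, eta=0.0):  # stand-in for scheduler.step with an optional eta argument
    return sample

accepts_eta = "eta" in set(inspect.signature(step).parameters.keys())
extra_step_kwargs = {"eta": 0.5} if accepts_eta else {}
print(accepts_eta, extra_step_kwargs)  # True {'eta': 0.5}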
@@ -761,14 +859,17 @@ class StableDiffusionPowerPaintBrushNetPipeline(
        control_guidance_end=1.0,
        callback_on_step_end_tensor_inputs=None,
    ):
        if callback_steps is not None and (
            not isinstance(callback_steps, int) or callback_steps <= 0
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs
            for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"

@@ -783,8 +884,12 @@ class StableDiffusionPowerPaintBrushNetPipeline(
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (
            not isinstance(prompt, str) and not isinstance(prompt, list)
        ):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(

@@ -820,7 +925,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
            and isinstance(self.brushnet._orig_mod, BrushNetModel)
        ):
            if not isinstance(brushnet_conditioning_scale, float):
                raise TypeError(
                    "For single brushnet: `brushnet_conditioning_scale` must be type `float`."
                )
        else:
            assert False

@@ -841,9 +948,13 @@ class StableDiffusionPowerPaintBrushNetPipeline(
                    f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
                )
            if start < 0.0:
                raise ValueError(
                    f"control guidance start: {start} can't be smaller than 0."
                )
            if end > 1.0:
                raise ValueError(
                    f"control guidance end: {end} can't be larger than 1.0."
                )

        if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
            raise ValueError(

@@ -864,8 +975,12 @@ class StableDiffusionPowerPaintBrushNetPipeline(
        image_is_pil = isinstance(image, PIL.Image.Image)
        image_is_tensor = isinstance(image, torch.Tensor)
        image_is_np = isinstance(image, np.ndarray)
        image_is_pil_list = isinstance(image, list) and isinstance(
            image[0], PIL.Image.Image
        )
        image_is_tensor_list = isinstance(image, list) and isinstance(
            image[0], torch.Tensor
        )
        image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)

        if (

@@ -883,8 +998,12 @@ class StableDiffusionPowerPaintBrushNetPipeline(
        mask_is_pil = isinstance(mask, PIL.Image.Image)
        mask_is_tensor = isinstance(mask, torch.Tensor)
        mask_is_np = isinstance(mask, np.ndarray)
        mask_is_pil_list = isinstance(mask, list) and isinstance(
            mask[0], PIL.Image.Image
        )
        mask_is_tensor_list = isinstance(mask, list) and isinstance(
            mask[0], torch.Tensor
        )
        mask_is_np_list = isinstance(mask, list) and isinstance(mask[0], np.ndarray)

        if (

@@ -928,7 +1047,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        image = self.image_processor.preprocess(image, height=height, width=width).to(
            dtype=torch.float32
        )
        image_batch_size = image.shape[0]

        if image_batch_size == 1:

@@ -947,8 +1068,23 @@ class StableDiffusionPowerPaintBrushNetPipeline(
        return image.to(device=device, dtype=dtype)

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(
        self,
        batch_size,
        num_channels_latents,
        height,
        width,
        dtype,
        device,
        generator,
        latents=None,
    ):
        shape = (
            batch_size,
            num_channels_latents,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"

@@ -1186,14 +1322,26 @@ class StableDiffusionPowerPaintBrushNetPipeline(
                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
            )

        brushnet = (
            self.brushnet._orig_mod
            if is_compiled_module(self.brushnet)
            else self.brushnet
        )

        # align format for control guidance
        if not isinstance(control_guidance_start, list) and isinstance(
            control_guidance_end, list
        ):
            control_guidance_start = len(control_guidance_end) * [
                control_guidance_start
            ]
        elif not isinstance(control_guidance_end, list) and isinstance(
            control_guidance_start, list
        ):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(
            control_guidance_end, list
        ):
            control_guidance_start, control_guidance_end = (
                [control_guidance_start],
                [control_guidance_end],

@@ -1241,7 +1389,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
        # 3. Encode input prompt
        text_encoder_lora_scale = (
            self.cross_attention_kwargs.get("scale", None)
            if self.cross_attention_kwargs is not None
            else None
        )

        prompt_embeds = self._encode_prompt(

@@ -1310,7 +1460,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
            assert False

        # 5. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps
        )
        self._num_timesteps = len(timesteps)

        # 6. Prepare latent variables

@@ -1330,14 +1482,15 @@ class StableDiffusionPowerPaintBrushNetPipeline(
        # mask_i = transforms.ToPILImage()(image[0:1,:,:,:].squeeze(0))
        # mask_i.save('_mask.png')
        # print(brushnet.dtype)
        conditioning_latents = (
            self.vae.encode(
                image.to(device=device, dtype=brushnet.dtype)
            ).latent_dist.sample()
            * self.vae.config.scaling_factor
        )
        mask = torch.nn.functional.interpolate(
            original_mask,
            size=(conditioning_latents.shape[-2], conditioning_latents.shape[-1]),
        )
        conditioning_latents = torch.concat([conditioning_latents, mask], 1)
        # image = self.vae.decode(conditioning_latents[:1,:4,:,:] / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]

@@ -1348,7 +1501,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
        # 6.5 Optionally get Guidance Scale Embedding
        timestep_cond = None
        if self.unet.config.time_cond_proj_dim is not None:
            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
                batch_size * num_images_per_prompt
            )
            timestep_cond = self.get_guidance_scale_embedding(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

@@ -1370,7 +1525,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
                1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                for s, e in zip(control_guidance_start, control_guidance_end)
            ]
            brushnet_keep.append(
                keeps[0] if isinstance(brushnet, BrushNetModel) else keeps
            )

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
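The `brushnet_keep` schedule above turns `control_guidance_start` / `control_guidance_end` into a per-step keep weight: a timestep contributes BrushNet residuals only while its normalized position lies inside the [start, end] window. A small worked example with assumed values (10 steps, a single window of 0.0–0.5):

num_inference_steps = 10
start, end = 0.0, 0.5

keeps = [
    1.0 - float(i / num_inference_steps < start or (i + 1) / num_inference_steps > end)
    for i in range(num_inference_steps)
]
print(keeps)  # [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]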
@@ -1381,31 +1538,45 @@ class StableDiffusionPowerPaintBrushNetPipeline(
            for i, t in enumerate(timesteps):
                # Relevant thread:
                # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
                if (
                    is_unet_compiled and is_brushnet_compiled
                ) and is_torch_higher_equal_2_1:
                    torch._inductor.cudagraph_mark_step_begin()
                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([latents] * 2)
                    if self.do_classifier_free_guidance
                    else latents
                )
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                # brushnet(s) inference
                if guess_mode and self.do_classifier_free_guidance:
                    # Infer BrushNet only for the conditional batch.
                    control_model_input = latents
                    control_model_input = self.scheduler.scale_model_input(
                        control_model_input, t
                    )
                    brushnet_prompt_embeds = prompt_embeds.chunk(2)[1]
                else:
                    control_model_input = latent_model_input
                    brushnet_prompt_embeds = prompt_embeds

                if isinstance(brushnet_keep[i], list):
                    cond_scale = [
                        c * s
                        for c, s in zip(brushnet_conditioning_scale, brushnet_keep[i])
                    ]
                else:
                    brushnet_cond_scale = brushnet_conditioning_scale
                    if isinstance(brushnet_cond_scale, list):
                        brushnet_cond_scale = brushnet_cond_scale[0]
                    cond_scale = brushnet_cond_scale * brushnet_keep[i]

                down_block_res_samples, mid_block_res_sample, up_block_res_samples = (
                    self.brushnet(
                        control_model_input,
                        t,
                        encoder_hidden_states=brushnet_prompt_embeds,
@ -1414,14 +1585,23 @@ class StableDiffusionPowerPaintBrushNetPipeline(
guess_mode=guess_mode, guess_mode=guess_mode,
return_dict=False, return_dict=False,
) )
)
if guess_mode and self.do_classifier_free_guidance: if guess_mode and self.do_classifier_free_guidance:
# Infered BrushNet only for the conditional batch. # Infered BrushNet only for the conditional batch.
# To apply the output of BrushNet to both the unconditional and conditional batches, # To apply the output of BrushNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged. # add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] down_block_res_samples = [
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) torch.cat([torch.zeros_like(d), d])
up_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in up_block_res_samples] for d in down_block_res_samples
]
mid_block_res_sample = torch.cat(
[torch.zeros_like(mid_block_res_sample), mid_block_res_sample]
)
up_block_res_samples = [
torch.cat([torch.zeros_like(d), d])
for d in up_block_res_samples
]
# predict the noise residual # predict the noise residual
noise_pred = self.unet( noise_pred = self.unet(
@@ -1440,10 +1620,14 @@ class StableDiffusionPowerPaintBrushNetPipeline(
                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False
                )[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
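The guidance step above is the standard classifier-free guidance update: the final noise prediction starts from the unconditional estimate and is pushed toward the text-conditioned one by `guidance_scale`. A toy numeric sketch with made-up tensor values:

import torch

guidance_scale = 7.5
noise_pred_uncond = torch.full((1, 4, 8, 8), 0.1)
noise_pred_text = torch.full((1, 4, 8, 8), 0.3)

noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
print(round(noise_pred[0, 0, 0, 0].item(), 4))  # 0.1 + 7.5 * (0.3 - 0.1) ≈ 1.6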
@@ -1453,10 +1637,14 @@ class StableDiffusionPowerPaintBrushNetPipeline(
                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop(
                        "negative_prompt_embeds", negative_prompt_embeds
                    )

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        step_idx = i // getattr(self.scheduler, "order", 1)

@@ -1470,10 +1658,14 @@ class StableDiffusionPowerPaintBrushNetPipeline(
        torch.cuda.empty_cache()

        if not output_type == "latent":
            image = self.vae.decode(
                latents / self.vae.config.scaling_factor,
                return_dict=False,
                generator=generator,
            )[0]
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, prompt_embeds.dtype
            )
        else:
            image = latents
            has_nsfw_concept = None

@@ -1483,7 +1675,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(
            image, output_type=output_type, do_denormalize=do_denormalize
        )

        # Offload all models
        self.maybe_free_model_hooks()

@@ -1491,4 +1685,6 @@ class StableDiffusionPowerPaintBrushNetPipeline(
        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept
        )
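One detail from the denoising loop above worth calling out: in guess mode with classifier-free guidance, BrushNet residuals are computed only for the conditional half of the batch and then padded with zeros for the unconditional half, so the unconditional UNet pass is left unchanged. A minimal sketch of that padding (shapes are illustrative):

import torch

cond_residual = torch.randn(1, 320, 64, 64)  # BrushNet residual for the conditional batch only

# CFG stacks [unconditional, conditional] along the batch dimension; concatenating zeros
# first means the unconditional half receives no BrushNet contribution.
cfg_residual = torch.cat([torch.zeros_like(cond_residual), cond_residual])

assert cfg_residual.shape[0] == 2 * cond_residual.shape[0]
assert torch.all(cfg_residual[0] == 0)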

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -122,9 +122,13 @@ class ModelInfo(BaseModel):
    @computed_field
    @property
    def support_powerpaint_v2(self) -> bool:
-       return self.model_type in [
-           ModelType.DIFFUSERS_SD,
-       ]
+       return (
+           self.model_type
+           in [
+               ModelType.DIFFUSERS_SD,
+           ]
+           and self.name != POWERPAINT_NAME
+       )


class Choices(str, Enum):
@@ -215,7 +219,6 @@ class SDSampler(str, Enum):
    lcm = "LCM"


class PowerPaintTask(Choices):
    text_guided = "text-guided"
    context_aware = "context-aware"
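The `support_powerpaint_v2` field above follows pydantic v2's computed-field pattern: a read-only property derived from other fields and included when the model is serialized. A minimal sketch of the same gating idea, with stand-in strings in place of the real ModelType / POWERPAINT_NAME constants:

from pydantic import BaseModel, computed_field

class ModelInfoSketch(BaseModel):
    name: str
    model_type: str  # stand-in for the ModelType enum

    @computed_field
    @property
    def support_powerpaint_v2(self) -> bool:
        # "diffusers_sd" and "powerpaint-v1" are illustrative placeholders only.
        return self.model_type == "diffusers_sd" and self.name != "powerpaint-v1"

info = ModelInfoSketch(name="runwayml/stable-diffusion-v1-5", model_type="diffusers_sd")
print(info.support_powerpaint_v2)  # True
print(info.model_dump()["support_powerpaint_v2"])  # computed fields are serialized too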

View File

@@ -59,6 +59,9 @@ const DiffusionOptions = () => {
    updateExtenderDirection,
    adjustMask,
    clearMask,
+   updateEnablePowerPaintV2,
+   updateEnableBrushNet,
+   updateEnableControlnet,
  ] = useStore((state) => [
    state.serverConfig.samplers,
    state.settings,

@@ -71,6 +74,9 @@ const DiffusionOptions = () => {
    state.updateExtenderDirection,
    state.adjustMask,
    state.clearMask,
+   state.updateEnablePowerPaintV2,
+   state.updateEnableBrushNet,
+   state.updateEnableControlnet,
  ])
  const [exampleImage, isExampleImageLoaded] = useImage(paintByExampleFile)
  const negativePromptRef = useRef(null)
@@ -114,12 +120,8 @@ const DiffusionOptions = () => {
      return null
    }

-   let disable = settings.enableControlnet
    let toolTip =
-     "BrushNet is a plug-and-play image inpainting model with decomposed dual-branch diffusion. It can be used to inpaint images by conditioning on a mask."
+     "BrushNet is a plug-and-play image inpainting model that works on any SD1.5 base model."
-   if (disable) {
-     toolTip = "ControlNet is enabled, BrushNet is disabled."
-   }

    return (
      <div className="flex flex-col gap-4">
@@ -129,20 +131,19 @@ const DiffusionOptions = () => {
            text="BrushNet"
            url="https://github.com/TencentARC/BrushNet"
            toolTip={toolTip}
-           disabled={disable}
          />
          <Switch
            id="brushnet"
            checked={settings.enableBrushNet}
            onCheckedChange={(value) => {
-             updateSettings({ enableBrushNet: value })
+             updateEnableBrushNet(value)
            }}
-           disabled={disable}
          />
        </RowContainer>
-       <RowContainer>
+       {/* <RowContainer>
          <Slider
            defaultValue={[100]}
            className="w-[180px]"
            min={1}
            max={100}
            step={1}

@@ -155,14 +156,13 @@ const DiffusionOptions = () => {
          <NumberInput
            id="brushnet-weight"
            className="w-[60px] rounded-full"
-           disabled={!settings.enableBrushNet || disable}
            numberValue={settings.brushnetConditioningScale}
            allowFloat={false}
            onNumberValueChange={(val) => {
              updateSettings({ brushnetConditioningScale: val })
            }}
          />
-       </RowContainer>
+       </RowContainer> */}

        <RowContainer>
          <Select
@@ -198,12 +198,8 @@ const DiffusionOptions = () => {
      return null
    }

-   let disable = settings.enableBrushNet
    let toolTip =
      "Using an additional conditioning image to control how an image is generated"
-   if (disable) {
-     toolTip = "BrushNet is enabled, ControlNet is disabled."
-   }

    return (
      <div className="flex flex-col gap-4">

@@ -213,15 +209,13 @@ const DiffusionOptions = () => {
            text="ControlNet"
            url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inpaint#controlnet"
            toolTip={toolTip}
-           disabled={disable}
          />
          <Switch
            id="controlnet"
            checked={settings.enableControlnet}
            onCheckedChange={(value) => {
-             updateSettings({ enableControlnet: value })
+             updateEnableControlnet(value)
            }}
-           disabled={disable}
          />
        </RowContainer>

@@ -233,7 +227,7 @@ const DiffusionOptions = () => {
            min={1}
            max={100}
            step={1}
-           disabled={!settings.enableControlnet || disable}
+           disabled={!settings.enableControlnet}
            value={[Math.floor(settings.controlnetConditioningScale * 100)]}
            onValueChange={(vals) =>
              updateSettings({ controlnetConditioningScale: vals[0] / 100 })

@@ -242,7 +236,7 @@ const DiffusionOptions = () => {
          <NumberInput
            id="controlnet-weight"
            className="w-[60px] rounded-full"
-           disabled={!settings.enableControlnet || disable}
+           disabled={!settings.enableControlnet}
            numberValue={settings.controlnetConditioningScale}
            allowFloat={false}
            onNumberValueChange={(val) => {

@@ -286,12 +280,8 @@ const DiffusionOptions = () => {
      return null
    }

-   let disable = settings.enableBrushNet
    let toolTip =
-     "Enable quality image generation in typically 2-4 steps. Suggest disabling guidance_scale by setting it to 0. You can also try values between 1.0 and 2.0. When LCM Lora is enabled, LCMSampler will be used automatically."
+     "Enable quality image generation in typically 2-8 steps. Suggest disabling guidance_scale by setting it to 0. You can also try values between 1.0 and 2.0. When LCM Lora is enabled, LCMSampler will be used automatically."
-   if (disable) {
-     toolTip = "BrushNet is enabled, LCM Lora is disabled."
-   }

    return (
      <>

@@ -300,7 +290,6 @@ const DiffusionOptions = () => {
          text="LCM LoRA"
          url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inference_with_lcm_lora"
          toolTip={toolTip}
-         disabled={disable}
        />
        <Switch
          id="lcm-lora"

@@ -308,7 +297,6 @@ const DiffusionOptions = () => {
          onCheckedChange={(value) => {
            updateSettings({ enableLCMLora: value })
          }}
-         disabled={disable}
        />
      </RowContainer>
      <Separator />
@@ -561,10 +549,6 @@ const DiffusionOptions = () => {
  }

  const renderPowerPaintTaskType = () => {
-   if (settings.model.name !== POWERPAINT) {
-     return null
-   }
-
    return (
      <RowContainer>
        <LabelTitle

@@ -579,7 +563,7 @@ const DiffusionOptions = () => {
          }}
          disabled={settings.showExtender}
        >
-         <SelectTrigger className="w-[140px]">
+         <SelectTrigger className="w-[130px]">
            <SelectValue placeholder="Select task" />
          </SelectTrigger>
          <SelectContent align="end">

@@ -587,6 +571,7 @@ const DiffusionOptions = () => {
            {[
              PowerPaintTask.text_guided,
              PowerPaintTask.object_remove,
+             PowerPaintTask.context_aware,
              PowerPaintTask.shape_guided,
            ].map((task) => (
              <SelectItem key={task} value={task}>
@@ -600,6 +585,44 @@ const DiffusionOptions = () => {
    )
  }

+ const renderPowerPaintV1 = () => {
+   if (settings.model.name !== POWERPAINT) {
+     return null
+   }
+   return (
+     <>
+       {renderPowerPaintTaskType()}
+       <Separator />
+     </>
+   )
+ }
+
+ const renderPowerPaintV2 = () => {
+   if (settings.model.support_powerpaint_v2 === false) {
+     return null
+   }
+
+   return (
+     <>
+       <RowContainer>
+         <LabelTitle
+           text="PowerPaint V2"
+           toolTip="PowerPaint is a plug-and-play image inpainting model that works on any SD1.5 base model."
+         />
+         <Switch
+           id="powerpaint-v2"
+           checked={settings.enablePowerPaintV2}
+           onCheckedChange={(value) => {
+             updateEnablePowerPaintV2(value)
+           }}
+         />
+       </RowContainer>
+       {renderPowerPaintTaskType()}
+       <Separator />
+     </>
+   )
+ }
+
  const renderSteps = () => {
    return (
      <RowContainer>
@@ -868,7 +891,7 @@ const DiffusionOptions = () => {
      {renderMaskBlur()}
      {renderMaskAdjuster()}
      {renderMatchHistograms()}
-     {renderPowerPaintTaskType()}
+     {renderPowerPaintV1()}
      {renderSteps()}
      {renderGuidanceScale()}
      {renderP2PImageGuidanceScale()}

@@ -878,6 +901,7 @@ const DiffusionOptions = () => {
      {renderNegativePrompt()}
      <Separator />
      {renderBrushNetSetting()}
+     {renderPowerPaintV2()}
      {renderConterNetSetting()}
      {renderLCMLora()}
      {renderPaintByExample()}

View File

@@ -22,7 +22,7 @@ const SelectTrigger = React.forwardRef<
  <SelectPrimitive.Trigger
    ref={ref}
    className={cn(
-     "flex h-9 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent px-3 py-2 text-sm shadow-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-1 focus:ring-ring disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1",
+     "flex h-9 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent pl-2 pr-1 py-2 text-sm shadow-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-1 focus:ring-ring disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1",
      className
    )}
    tabIndex={-1}

View File

@@ -79,6 +79,7 @@ export default async function inpaint(
    enable_brushnet: settings.enableBrushNet,
    brushnet_method: settings.brushnetMethod ? settings.brushnetMethod : "",
    brushnet_conditioning_scale: settings.brushnetConditioningScale,
+   enable_powerpaint_v2: settings.enablePowerPaintV2,
    powerpaint_task: settings.showExtender
      ? PowerPaintTask.outpainting
      : settings.powerpaintTask,

View File

@@ -106,6 +106,7 @@ export type Settings = {
  enableLCMLora: boolean

  // PowerPaint
+ enablePowerPaintV2: boolean
  powerpaintTask: PowerPaintTask

  // AdjustMask
@@ -194,6 +195,12 @@ type AppAction = {
  setServerConfig: (newValue: ServerConfig) => void
  setSeed: (newValue: number) => void
  updateSettings: (newSettings: Partial<Settings>) => void
+ // Mutually exclusive: enabling one of these disables the others
+ updateEnablePowerPaintV2: (newValue: boolean) => void
+ updateEnableBrushNet: (newValue: boolean) => void
+ updateEnableControlnet: (newValue: boolean) => void
+
  setModel: (newModel: ModelInfo) => void
  updateFileManagerState: (newState: Partial<FileManagerState>) => void
  updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
@@ -311,6 +318,7 @@ const defaultValues: AppState = {
    support_brushnet: false,
    support_strength: false,
    support_outpainting: false,
+   support_powerpaint_v2: false,
    controlnets: [],
    brushnets: [],
    support_lcm_lora: false,

@@ -425,6 +433,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
        if (
          get().settings.model.support_outpainting &&
          settings.showExtender &&
+         extenderState.x === 0 &&
+         extenderState.y === 0 &&
          extenderState.height === imageHeight &&
          extenderState.width === imageWidth
        ) {
@@ -798,6 +808,38 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
        })
      },

+     updateEnablePowerPaintV2: (newValue: boolean) => {
+       get().updateSettings({ enablePowerPaintV2: newValue })
+       if (newValue) {
+         get().updateSettings({
+           enableBrushNet: false,
+           enableControlnet: false,
+           enableLCMLora: false,
+         })
+       }
+     },
+
+     updateEnableBrushNet: (newValue: boolean) => {
+       get().updateSettings({ enableBrushNet: newValue })
+       if (newValue) {
+         get().updateSettings({
+           enablePowerPaintV2: false,
+           enableControlnet: false,
+           enableLCMLora: false,
+         })
+       }
+     },
+
+     updateEnableControlnet(newValue) {
+       get().updateSettings({ enableControlnet: newValue })
+       if (newValue) {
+         get().updateSettings({
+           enablePowerPaintV2: false,
+           enableBrushNet: false,
+         })
+       }
+     },
+
      setModel: (newModel: ModelInfo) => {
        set((state) => {
          state.settings.model = newModel
View File

@@ -49,6 +49,7 @@ export interface ModelInfo {
  support_outpainting: boolean
  support_controlnet: boolean
  support_brushnet: boolean
+ support_powerpaint_v2: boolean
  controlnets: string[]
  brushnets: string[]
  support_lcm_lora: boolean

@@ -123,6 +124,7 @@ export enum ExtenderDirection {
export enum PowerPaintTask {
  text_guided = "text-guided",
  shape_guided = "shape-guided",
+ context_aware = "context-aware",
  object_remove = "object-remove",
  outpainting = "outpainting",
}