This commit is contained in:
Qing 2024-04-29 22:20:44 +08:00
parent 017a3d68fd
commit 80ee1b9941
11 changed files with 1548 additions and 5684 deletions

View File

@ -1,6 +1,9 @@
from itertools import chain
import PIL.Image
import cv2
import torch
from iopaint.model.original_sd_configs import get_config_files
from loguru import logger
from transformers import CLIPTextModel, CLIPTokenizer
import numpy as np
@ -14,9 +17,15 @@ from ..utils import (
handle_from_pretrained_exceptions,
)
from .powerpaint_tokenizer import task_to_prompt
from iopaint.schema import InpaintRequest
from iopaint.schema import InpaintRequest, ModelType
from .v2.BrushNet_CA import BrushNetModel
from .v2.unet_2d_condition import UNet2DConditionModel
from .v2.unet_2d_condition import UNet2DConditionModel_forward
from .v2.unet_2d_blocks import (
CrossAttnDownBlock2D_forward,
DownBlock2D_forward,
CrossAttnUpBlock2D_forward,
UpBlock2D_forward,
)
class PowerPaintV2(DiffusionInpaintModel):
@ -50,14 +59,7 @@ class PowerPaintV2(DiffusionInpaintModel):
torch_dtype=torch_dtype,
local_files_only=model_kwargs["local_files_only"],
)
unet = handle_from_pretrained_exceptions(
UNet2DConditionModel.from_pretrained,
pretrained_model_name_or_path=self.model_id_or_path,
subfolder="unet",
variant="fp16",
torch_dtype=torch_dtype,
local_files_only=model_kwargs["local_files_only"],
)
brushnet = BrushNetModel.from_pretrained(
self.hf_model_id,
subfolder="PowerPaint_Brushnet",
@ -65,11 +67,27 @@ class PowerPaintV2(DiffusionInpaintModel):
torch_dtype=torch_dtype,
local_files_only=model_kwargs["local_files_only"],
)
if self.model_info.is_single_file_diffusers:
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
model_kwargs["num_in_channels"] = 4
else:
model_kwargs["num_in_channels"] = 9
pipe = StableDiffusionPowerPaintBrushNetPipeline.from_single_file(
self.model_id_or_path,
torch_dtype=torch_dtype,
load_safety_checker=False,
original_config_file=get_config_files()["v1"],
brushnet=brushnet,
text_encoder_brushnet=text_encoder_brushnet,
**model_kwargs,
)
else:
pipe = handle_from_pretrained_exceptions(
StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
pretrained_model_name_or_path=self.model_id_or_path,
torch_dtype=torch_dtype,
unet=unet,
brushnet=brushnet,
text_encoder_brushnet=text_encoder_brushnet,
variant="fp16",
@ -95,6 +113,34 @@ class PowerPaintV2(DiffusionInpaintModel):
self.callback = kwargs.pop("callback", None)
# Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
self.model.unet.forward = UNet2DConditionModel_forward.__get__(
self.model.unet, self.model.unet.__class__
)
# Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
for down_block in chain(
self.model.unet.down_blocks, self.model.brushnet.down_blocks
):
if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
down_block.forward = CrossAttnDownBlock2D_forward.__get__(
down_block, down_block.__class__
)
else:
down_block.forward = DownBlock2D_forward.__get__(
down_block, down_block.__class__
)
for up_block in chain(self.model.unet.up_blocks, self.model.brushnet.up_blocks):
if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
up_block.forward = CrossAttnUpBlock2D_forward.__get__(
up_block, up_block.__class__
)
else:
up_block.forward = UpBlock2D_forward.__get__(
up_block, up_block.__class__
)
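
The patching above relies on Python's descriptor protocol: assigning `some_function.__get__(obj, cls)` turns a plain function into a method bound to that one instance, so `self` is supplied automatically while every other instance keeps the original behaviour. A minimal sketch of the pattern, with an illustrative class and replacement function that are not part of IOPaint:

# Minimal sketch of per-instance monkey patching via the descriptor protocol.
# `Block` and `patched_forward` are illustrative stand-ins, not IOPaint code.
class Block:
    def forward(self, x):
        return x + 1

def patched_forward(self, x):
    # Replacement that still receives `self` like a normal method.
    return x * 2

block = Block()
other = Block()

# __get__ binds the plain function to this specific instance only.
block.forward = patched_forward.__get__(block, Block)

print(block.forward(3))  # 6 -> patched
print(other.forward(3))  # 4 -> untouched
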
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
@ -129,11 +175,10 @@ class PowerPaintV2(DiffusionInpaintModel):
brushnet_conditioning_scale=1.0,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
callback_on_step_end=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),
callback_steps=1,
).images[0]
output = (output * 255).round().astype("uint8")
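
The call above switches from the deprecated `callback`/`callback_steps` arguments to the newer diffusers `callback_on_step_end` convention, in which the callable receives the pipeline, the step index, the timestep, and a dict of tensors and returns that dict. A sketch of a callback matching that convention (the pipeline variable and tensor names are assumptions based on the diffusers convention, not IOPaint code):

# Sketch of the newer diffusers step-end callback convention assumed by this change;
# `pipe` is any diffusers pipeline that exposes callback_on_step_end.
def on_step_end(pipe, step: int, timestep, callback_kwargs: dict) -> dict:
    # Called once per denoising step; tensors named in
    # callback_on_step_end_tensor_inputs (e.g. "latents") arrive in the dict.
    latents = callback_kwargs.get("latents")
    print(f"step={step}, t={int(timestep)}, latents={tuple(latents.shape)}")
    return callback_kwargs  # must return the dict (optionally modified)

# result = pipe(prompt="...", callback_on_step_end=on_step_end)
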

View File

@ -2,6 +2,14 @@ from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.unet_2d_blocks import (
get_down_block,
get_mid_block,
get_up_block,
CrossAttnDownBlock2D,
DownBlock2D,
)
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
@ -13,18 +21,14 @@ from diffusers.models.attention_processor import (
AttnAddedKVProcessor,
AttnProcessor,
)
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, \
TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from .unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
get_down_block,
get_mid_block,
get_up_block
from diffusers.models.embeddings import (
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
TimestepEmbedding,
Timesteps,
)
from .unet_2d_condition import UNet2DConditionModel
from diffusers.models.modeling_utils import ModelMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@ -145,7 +149,10 @@ class BrushNetModel(ModelMixin, ConfigMixin):
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str, ...] = (
"UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
@ -170,7 +177,12 @@ class BrushNetModel(ModelMixin, ConfigMixin):
resnet_time_scale_shift: str = "default",
projection_class_embeddings_input_dim: Optional[int] = None,
brushnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (
16,
32,
96,
256,
),
global_pool_conditions: bool = False,
addition_embed_type_num_heads: int = 64,
):
@ -195,25 +207,33 @@ class BrushNetModel(ModelMixin, ConfigMixin):
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
if not isinstance(only_cross_attention, bool) and len(
only_cross_attention
) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
down_block_types
):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
transformer_layers_per_block = [transformer_layers_per_block] * len(
down_block_types
)
# input
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in_condition = nn.Conv2d(
in_channels + conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel,
padding=conv_in_padding
in_channels + conditioning_channels,
block_out_channels[0],
kernel_size=conv_in_kernel,
padding=conv_in_padding,
)
# time
@ -229,7 +249,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
logger.info(
"encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
)
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
raise ValueError(
@ -274,7 +296,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
self.class_embedding = TimestepEmbedding(
projection_class_embeddings_input_dim, time_embed_dim
)
else:
self.class_embedding = None
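
As the comment notes, `TimestepEmbedding` is essentially a small MLP, so it can project arbitrary vectors; real timesteps are first converted to sinusoidal features by `Timesteps`. A brief sketch of that two-stage path, assuming the diffusers embedding classes imported in this file and the usual SD 1.5 sizes:

import torch
from diffusers.models.embeddings import Timesteps, TimestepEmbedding

block_out_channels_0 = 320               # first entry of block_out_channels
time_embed_dim = 4 * block_out_channels_0

# Stage 1: integer timesteps -> sinusoidal features, shape [B, 320]
time_proj = Timesteps(block_out_channels_0, flip_sin_to_cos=True, downscale_freq_shift=0)
# Stage 2: small MLP, shape [B, 320] -> [B, 1280]
time_embedding = TimestepEmbedding(block_out_channels_0, time_embed_dim)

t = torch.tensor([10, 500])                  # a batch of two timesteps
emb = time_embedding(time_proj(t).float())   # -> [2, 1280]
print(emb.shape)
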
@ -285,21 +309,31 @@ class BrushNetModel(ModelMixin, ConfigMixin):
text_time_embedding_from_dim = cross_attention_dim
self.add_embedding = TextTimeEmbedding(
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
text_time_embedding_from_dim,
time_embed_dim,
num_heads=addition_embed_type_num_heads,
)
elif addition_embed_type == "text_image":
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
text_embed_dim=cross_attention_dim,
image_embed_dim=cross_attention_dim,
time_embed_dim=time_embed_dim,
)
elif addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
self.add_time_proj = Timesteps(
addition_time_embed_dim, flip_sin_to_cos, freq_shift
)
self.add_embedding = TimestepEmbedding(
projection_class_embeddings_input_dim, time_embed_dim
)
elif addition_embed_type is not None:
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
raise ValueError(
f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
)
self.down_blocks = nn.ModuleList([])
self.brushnet_down_blocks = nn.ModuleList([])
@ -338,7 +372,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[i],
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
attention_head_dim=attention_head_dim[i]
if attention_head_dim[i] is not None
else output_channel,
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
@ -348,12 +384,16 @@ class BrushNetModel(ModelMixin, ConfigMixin):
self.down_blocks.append(down_block)
for _ in range(layers_per_block):
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
brushnet_block = nn.Conv2d(
output_channel, output_channel, kernel_size=1
)
brushnet_block = zero_module(brushnet_block)
self.brushnet_down_blocks.append(brushnet_block)
if not is_final_block:
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
brushnet_block = nn.Conv2d(
output_channel, output_channel, kernel_size=1
)
brushnet_block = zero_module(brushnet_block)
self.brushnet_down_blocks.append(brushnet_block)
@ -386,7 +426,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
reversed_transformer_layers_per_block = list(
reversed(transformer_layers_per_block)
)
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
@ -399,7 +441,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
input_channel = reversed_block_out_channels[
min(i + 1, len(block_out_channels) - 1)
]
# add upsample block for all BUT final layer
if not is_final_block:
@ -427,18 +471,24 @@ class BrushNetModel(ModelMixin, ConfigMixin):
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
attention_head_dim=attention_head_dim[i]
if attention_head_dim[i] is not None
else output_channel,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
for _ in range(layers_per_block + 1):
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
brushnet_block = nn.Conv2d(
output_channel, output_channel, kernel_size=1
)
brushnet_block = zero_module(brushnet_block)
self.brushnet_up_blocks.append(brushnet_block)
if not is_final_block:
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
brushnet_block = nn.Conv2d(
output_channel, output_channel, kernel_size=1
)
brushnet_block = zero_module(brushnet_block)
self.brushnet_up_blocks.append(brushnet_block)
@ -447,7 +497,12 @@ class BrushNetModel(ModelMixin, ConfigMixin):
cls,
unet: UNet2DConditionModel,
brushnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (
16,
32,
96,
256,
),
load_weights_from_unet: bool = True,
conditioning_channels: int = 5,
):
@ -460,13 +515,27 @@ class BrushNetModel(ModelMixin, ConfigMixin):
where applicable.
"""
transformer_layers_per_block = (
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
unet.config.transformer_layers_per_block
if "transformer_layers_per_block" in unet.config
else 1
)
encoder_hid_dim = (
unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
)
encoder_hid_dim_type = (
unet.config.encoder_hid_dim_type
if "encoder_hid_dim_type" in unet.config
else None
)
addition_embed_type = (
unet.config.addition_embed_type
if "addition_embed_type" in unet.config
else None
)
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
addition_time_embed_dim = (
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
unet.config.addition_time_embed_dim
if "addition_time_embed_dim" in unet.config
else None
)
brushnet = cls(
@ -475,14 +544,21 @@ class BrushNetModel(ModelMixin, ConfigMixin):
flip_sin_to_cos=unet.config.flip_sin_to_cos,
freq_shift=unet.config.freq_shift,
# down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
down_block_types=["CrossAttnDownBlock2D",
down_block_types=[
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D", ],
"CrossAttnDownBlock2D",
"DownBlock2D",
],
# mid_block_type='MidBlock2D',
mid_block_type="UNetMidBlock2DCrossAttn",
# up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
up_block_types=[
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
],
only_cross_attention=unet.config.only_cross_attention,
block_out_channels=unet.config.block_out_channels,
layers_per_block=unet.config.layers_per_block,
@ -510,21 +586,33 @@ class BrushNetModel(ModelMixin, ConfigMixin):
)
if load_weights_from_unet:
conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
conv_in_condition_weight = torch.zeros_like(
brushnet.conv_in_condition.weight
)
conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
brushnet.conv_in_condition.weight = torch.nn.Parameter(
conv_in_condition_weight
)
brushnet.conv_in_condition.bias = unet.conv_in.bias
brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
if brushnet.class_embedding:
brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
brushnet.class_embedding.load_state_dict(
unet.class_embedding.state_dict()
)
brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
brushnet.down_blocks.load_state_dict(
unet.down_blocks.state_dict(), strict=False
)
brushnet.mid_block.load_state_dict(
unet.mid_block.state_dict(), strict=False
)
brushnet.up_blocks.load_state_dict(
unet.up_blocks.state_dict(), strict=False
)
return brushnet.to(unet.dtype)
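
The `load_weights_from_unet` branch above widens the pretrained 4-channel `conv_in` to BrushNet's `in_channels + conditioning_channels` input by copying the UNet kernel into the first two 4-channel slices of a zero tensor and leaving the mask channel at zero. A standalone sketch of that weight-expansion trick in plain torch (channel counts assumed from this file: 4 latent channels, 5 conditioning channels, 320 output channels):

import torch
import torch.nn as nn

old_conv = nn.Conv2d(4, 320, kernel_size=3, padding=1)      # stands in for unet.conv_in
new_conv = nn.Conv2d(4 + 5, 320, kernel_size=3, padding=1)  # latent + (masked-image latent, mask)

with torch.no_grad():
    w = torch.zeros_like(new_conv.weight)   # [320, 9, 3, 3]
    w[:, :4] = old_conv.weight              # reuse the pretrained kernel for the noisy latent
    w[:, 4:8] = old_conv.weight             # and for the masked-image latent
    # channel 8 (the mask) stays zero, so the extra input starts as a no-op
    new_conv.weight.copy_(w)
    new_conv.bias.copy_(old_conv.bias)
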
@ -539,9 +627,15 @@ class BrushNetModel(ModelMixin, ConfigMixin):
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
def fn_recursive_add_processors(
name: str,
module: torch.nn.Module,
processors: Dict[str, AttentionProcessor],
):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
processors[f"{name}.processor"] = module.get_processor(
return_deprecated_lora=True
)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@ -554,7 +648,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
):
r"""
Sets the attention processor to use to compute attention.
@ -593,9 +689,15 @@ class BrushNetModel(ModelMixin, ConfigMixin):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
if all(
proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
for proc in self.attn_processors.values()
):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
elif all(
proc.__class__ in CROSS_ATTENTION_PROCESSORS
for proc in self.attn_processors.values()
):
processor = AttnProcessor()
else:
raise ValueError(
@ -642,7 +744,11 @@ class BrushNetModel(ModelMixin, ConfigMixin):
# make smallest slice possible
slice_size = num_sliceable_layers * [1]
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
slice_size = (
num_sliceable_layers * [slice_size]
if not isinstance(slice_size, list)
else slice_size
)
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
@ -659,7 +765,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
def fn_recursive_set_attention_slice(
module: torch.nn.Module, slice_size: List[int]
):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
@ -737,7 +845,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
elif channel_order == "bgr":
brushnet_cond = torch.flip(brushnet_cond, dims=[1])
else:
raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
raise ValueError(
f"unknown `brushnet_conditioning_channel_order`: {channel_order}"
)
# prepare attention_mask
if attention_mask is not None:
@ -773,7 +883,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
raise ValueError(
"class_labels should be provided when num_class_embeds > 0"
)
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
@ -812,7 +924,10 @@ class BrushNetModel(ModelMixin, ConfigMixin):
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
if (
hasattr(downsample_block, "has_cross_attention")
and downsample_block.has_cross_attention
):
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
@ -827,13 +942,20 @@ class BrushNetModel(ModelMixin, ConfigMixin):
# 4. PaintingNet down blocks
brushnet_down_block_res_samples = ()
for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
for down_block_res_sample, brushnet_down_block in zip(
down_block_res_samples, self.brushnet_down_blocks
):
down_block_res_sample = brushnet_down_block(down_block_res_sample)
brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
brushnet_down_block_res_samples = brushnet_down_block_res_samples + (
down_block_res_sample,
)
# 5. mid
if self.mid_block is not None:
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
if (
hasattr(self.mid_block, "has_cross_attention")
and self.mid_block.has_cross_attention
):
sample = self.mid_block(
sample,
emb,
@ -852,15 +974,20 @@ class BrushNetModel(ModelMixin, ConfigMixin):
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[
: -len(upsample_block.resnets)
]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
if (
hasattr(upsample_block, "has_cross_attention")
and upsample_block.has_cross_attention
):
sample, up_res_samples = upsample_block(
hidden_states=sample,
temb=emb,
@ -869,7 +996,7 @@ class BrushNetModel(ModelMixin, ConfigMixin):
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
return_res_samples=True
return_res_samples=True,
)
else:
sample, up_res_samples = upsample_block(
@ -877,53 +1004,87 @@ class BrushNetModel(ModelMixin, ConfigMixin):
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
return_res_samples=True
return_res_samples=True,
)
up_block_res_samples += up_res_samples
# 8. BrushNet up blocks
brushnet_up_block_res_samples = ()
for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
for up_block_res_sample, brushnet_up_block in zip(
up_block_res_samples, self.brushnet_up_blocks
):
up_block_res_sample = brushnet_up_block(up_block_res_sample)
brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
brushnet_up_block_res_samples = brushnet_up_block_res_samples + (
up_block_res_sample,
)
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0,
len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
device=sample.device) # 0.1 to 1.0
scales = torch.logspace(
-1,
0,
len(brushnet_down_block_res_samples)
+ 1
+ len(brushnet_up_block_res_samples),
device=sample.device,
) # 0.1 to 1.0
scales = scales * conditioning_scale
brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples,
scales[:len(
brushnet_down_block_res_samples)])]
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples,
scales[
len(brushnet_down_block_res_samples) + 1:])]
brushnet_down_block_res_samples = [
sample * scale
for sample, scale in zip(
brushnet_down_block_res_samples,
scales[: len(brushnet_down_block_res_samples)],
)
]
brushnet_mid_block_res_sample = (
brushnet_mid_block_res_sample
* scales[len(brushnet_down_block_res_samples)]
)
brushnet_up_block_res_samples = [
sample * scale
for sample, scale in zip(
brushnet_up_block_res_samples,
scales[len(brushnet_down_block_res_samples) + 1 :],
)
]
else:
brushnet_down_block_res_samples = [sample * conditioning_scale for sample in
brushnet_down_block_res_samples]
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
brushnet_down_block_res_samples = [
sample * conditioning_scale
for sample in brushnet_down_block_res_samples
]
brushnet_mid_block_res_sample = (
brushnet_mid_block_res_sample * conditioning_scale
)
brushnet_up_block_res_samples = [
sample * conditioning_scale for sample in brushnet_up_block_res_samples
]
if self.config.global_pool_conditions:
brushnet_down_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
torch.mean(sample, dim=(2, 3), keepdim=True)
for sample in brushnet_down_block_res_samples
]
brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
brushnet_mid_block_res_sample = torch.mean(
brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True
)
brushnet_up_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
torch.mean(sample, dim=(2, 3), keepdim=True)
for sample in brushnet_up_block_res_samples
]
if not return_dict:
return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
return (
brushnet_down_block_res_samples,
brushnet_mid_block_res_sample,
brushnet_up_block_res_samples,
)
return BrushNetOutput(
down_block_res_samples=brushnet_down_block_res_samples,
mid_block_res_sample=brushnet_mid_block_res_sample,
up_block_res_samples=brushnet_up_block_res_samples
up_block_res_samples=brushnet_up_block_res_samples,
)
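
In guess mode the residuals above are weighted by `torch.logspace(-1, 0, N)`, i.e. N values spaced logarithmically between 0.1 and 1.0, so shallow blocks contribute less than deep ones before `conditioning_scale` is applied. A quick sketch of those scales with an illustrative residual count:

import torch

# e.g. 12 down-block residuals + 1 mid residual + 12 up-block residuals
n_residuals = 12 + 1 + 12
scales = torch.logspace(-1, 0, n_residuals)   # 0.1 ... 1.0, log-spaced
conditioning_scale = 0.8
scales = scales * conditioning_scale

print(scales[0].item(), scales[-1].item())    # ~0.08 and 0.8
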

View File

@ -5,11 +5,21 @@ import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionMixin
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers import StableDiffusionMixin, UNet2DConditionModel
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.loaders import (
FromSingleFileMixin,
IPAdapterMixin,
LoraLoaderMixin,
TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, ImageProjection
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
@ -21,13 +31,20 @@ from diffusers.utils import (
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from diffusers.utils.torch_utils import (
is_compiled_module,
is_torch_version,
randn_tensor,
)
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.pipeline_output import (
StableDiffusionPipelineOutput,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from .BrushNet_CA import BrushNetModel
from .unet_2d_condition import UNet2DConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@ -112,7 +129,9 @@ def retrieve_timesteps(
second element is the number of inference steps.
"""
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
accepts_timesteps = "timesteps" in set(
inspect.signature(scheduler.set_timesteps).parameters.keys()
)
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
@ -220,7 +239,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
self.image_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
)
self.register_to_config(requires_safety_checker=requires_safety_checker)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
@ -301,21 +322,25 @@ class StableDiffusionPowerPaintBrushNetPipeline(
)
text_input_idsA = text_inputsA.input_ids
text_input_idsB = text_inputsB.input_ids
untruncated_ids = self.tokenizer(promptA, padding="longest", return_tensors="pt").input_ids
untruncated_ids = self.tokenizer(
promptA, padding="longest", return_tensors="pt"
).input_ids
if untruncated_ids.shape[-1] >= text_input_idsA.shape[-1] and not torch.equal(
text_input_idsA, untruncated_ids
):
if untruncated_ids.shape[-1] >= text_input_idsA.shape[
-1
] and not torch.equal(text_input_idsA, untruncated_ids):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder_brushnet.config,
"use_attention_mask") and self.text_encoder_brushnet.config.use_attention_mask:
if (
hasattr(self.text_encoder_brushnet.config, "use_attention_mask")
and self.text_encoder_brushnet.config.use_attention_mask
):
attention_mask = text_inputsA.attention_mask.to(device)
else:
attention_mask = None
@ -350,7 +375,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds = prompt_embeds.view(
bs_embed * num_images_per_prompt, seq_len, -1
)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
@ -379,8 +406,12 @@ class StableDiffusionPowerPaintBrushNetPipeline(
# textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
uncond_tokensA = self.maybe_convert_prompt(uncond_tokensA, self.tokenizer)
uncond_tokensB = self.maybe_convert_prompt(uncond_tokensB, self.tokenizer)
uncond_tokensA = self.maybe_convert_prompt(
uncond_tokensA, self.tokenizer
)
uncond_tokensB = self.maybe_convert_prompt(
uncond_tokensB, self.tokenizer
)
max_length = prompt_embeds.shape[1]
uncond_inputA = self.tokenizer(
@ -398,8 +429,10 @@ class StableDiffusionPowerPaintBrushNetPipeline(
return_tensors="pt",
)
if hasattr(self.text_encoder_brushnet.config,
"use_attention_mask") and self.text_encoder_brushnet.config.use_attention_mask:
if (
hasattr(self.text_encoder_brushnet.config, "use_attention_mask")
and self.text_encoder_brushnet.config.use_attention_mask
):
attention_mask = uncond_inputA.attention_mask.to(device)
else:
attention_mask = None
@ -412,7 +445,10 @@ class StableDiffusionPowerPaintBrushNetPipeline(
uncond_inputB.input_ids.to(device),
attention_mask=attention_mask,
)
negative_prompt_embeds = negative_prompt_embedsA[0] * (t_nag) + (1 - t_nag) * negative_prompt_embedsB[0]
negative_prompt_embeds = (
negative_prompt_embedsA[0] * (t_nag)
+ (1 - t_nag) * negative_prompt_embedsB[0]
)
# negative_prompt_embeds = negative_prompt_embeds[0]
@ -420,10 +456,16 @@ class StableDiffusionPowerPaintBrushNetPipeline(
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.to(
dtype=prompt_embeds_dtype, device=device
)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_prompt_embeds = negative_prompt_embeds.repeat(
1, num_images_per_prompt, 1
)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, -1
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
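
As the comment above says, classifier-free guidance runs the unconditional and text-conditioned branches in one batched forward pass by concatenating the two embedding sets, then splits the prediction and recombines it with the guidance scale. A compact sketch of that round trip with stand-in tensors in place of real CLIP embeddings and UNet output:

import torch

# stand-ins for real CLIP embeddings: [batch, seq_len, dim]
negative_prompt_embeds = torch.randn(1, 77, 768)
prompt_embeds = torch.randn(1, 77, 768)

# single batched forward: unconditional first, conditional second
embeds = torch.cat([negative_prompt_embeds, prompt_embeds])   # [2, 77, 768]
noise_pred = torch.randn(2, 4, 64, 64)  # stands in for unet(latents_x2, t, encoder_hidden_states=embeds)

noise_uncond, noise_text = noise_pred.chunk(2)
guidance_scale = 7.5
guided = noise_uncond + guidance_scale * (noise_text - noise_uncond)
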
@ -511,30 +553,39 @@ class StableDiffusionPowerPaintBrushNetPipeline(
)
text_input_ids = text_inputs.input_ids
# print(prompt, text_input_ids)
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
untruncated_ids = self.tokenizer(
prompt, padding="longest", return_tensors="pt"
).input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
if untruncated_ids.shape[-1] >= text_input_ids.shape[
-1
] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(
untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
)
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
if (
hasattr(self.text_encoder.config, "use_attention_mask")
and self.text_encoder.config.use_attention_mask
):
attention_mask = text_inputs.attention_mask.to(device)
else:
attention_mask = None
if clip_skip is None:
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
prompt_embeds = self.text_encoder(
text_input_ids.to(device), attention_mask=attention_mask
)
prompt_embeds = prompt_embeds[0]
else:
prompt_embeds = self.text_encoder(
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
text_input_ids.to(device),
attention_mask=attention_mask,
output_hidden_states=True,
)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
@ -544,7 +595,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
# representations. The `last_hidden_states` that we typically use for
# obtaining the final prompt representations passes through the LayerNorm
# layer.
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
prompt_embeds = self.text_encoder.text_model.final_layer_norm(
prompt_embeds
)
if self.text_encoder is not None:
prompt_embeds_dtype = self.text_encoder.dtype
@ -558,7 +611,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_embeds = prompt_embeds.view(
bs_embed * num_images_per_prompt, seq_len, -1
)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
@ -595,7 +650,10 @@ class StableDiffusionPowerPaintBrushNetPipeline(
)
# print("neg: ", uncond_input.input_ids)
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
if (
hasattr(self.text_encoder.config, "use_attention_mask")
and self.text_encoder.config.use_attention_mask
):
attention_mask = uncond_input.attention_mask.to(device)
else:
attention_mask = None
@ -610,10 +668,16 @@ class StableDiffusionPowerPaintBrushNetPipeline(
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.to(
dtype=prompt_embeds_dtype, device=device
)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_prompt_embeds = negative_prompt_embeds.repeat(
1, num_images_per_prompt, 1
)
negative_prompt_embeds = negative_prompt_embeds.view(
batch_size * num_images_per_prompt, seq_len, -1
)
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
@ -624,7 +688,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
return prompt_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
def encode_image(
self, image, device, num_images_per_prompt, output_hidden_states=None
):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
@ -632,14 +698,20 @@ class StableDiffusionPowerPaintBrushNetPipeline(
image = image.to(device=device, dtype=dtype)
if output_hidden_states:
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
image_enc_hidden_states = self.image_encoder(
image, output_hidden_states=True
).hidden_states[-2]
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
uncond_image_enc_hidden_states = self.image_encoder(
torch.zeros_like(image), output_hidden_states=True
).hidden_states[-2]
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
uncond_image_enc_hidden_states = (
uncond_image_enc_hidden_states.repeat_interleave(
num_images_per_prompt, dim=0
)
)
return image_enc_hidden_states, uncond_image_enc_hidden_states
else:
image_embeds = self.image_encoder(image).image_embeds
@ -650,13 +722,20 @@ class StableDiffusionPowerPaintBrushNetPipeline(
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
self,
ip_adapter_image,
ip_adapter_image_embeds,
device,
num_images_per_prompt,
do_classifier_free_guidance,
):
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
if len(ip_adapter_image) != len(
self.unet.encoder_hid_proj.image_projection_layers
):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
)
@ -669,13 +748,17 @@ class StableDiffusionPowerPaintBrushNetPipeline(
single_image_embeds, single_negative_image_embeds = self.encode_image(
single_ip_adapter_image, device, 1, output_hidden_state
)
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
single_image_embeds = torch.stack(
[single_image_embeds] * num_images_per_prompt, dim=0
)
single_negative_image_embeds = torch.stack(
[single_negative_image_embeds] * num_images_per_prompt, dim=0
)
if do_classifier_free_guidance:
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
single_image_embeds = torch.cat(
[single_negative_image_embeds, single_image_embeds]
)
single_image_embeds = single_image_embeds.to(device)
image_embeds.append(single_image_embeds)
@ -684,17 +767,24 @@ class StableDiffusionPowerPaintBrushNetPipeline(
image_embeds = []
for single_image_embeds in ip_adapter_image_embeds:
if do_classifier_free_guidance:
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
single_negative_image_embeds, single_image_embeds = (
single_image_embeds.chunk(2)
)
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
num_images_per_prompt,
*(repeat_dims * len(single_image_embeds.shape[1:])),
)
single_negative_image_embeds = single_negative_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
num_images_per_prompt,
*(repeat_dims * len(single_negative_image_embeds.shape[1:])),
)
single_image_embeds = torch.cat(
[single_negative_image_embeds, single_image_embeds]
)
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
else:
single_image_embeds = single_image_embeds.repeat(
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
num_images_per_prompt,
*(repeat_dims * len(single_image_embeds.shape[1:])),
)
image_embeds.append(single_image_embeds)
@ -706,10 +796,14 @@ class StableDiffusionPowerPaintBrushNetPipeline(
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
feature_extractor_input = self.image_processor.postprocess(
image, output_type="pil"
)
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
safety_checker_input = self.feature_extractor(
feature_extractor_input, return_tensors="pt"
).to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
@ -734,13 +828,17 @@ class StableDiffusionPowerPaintBrushNetPipeline(
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
accepts_eta = "eta" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
accepts_generator = "generator" in set(
inspect.signature(self.scheduler.step).parameters.keys()
)
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
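
`prepare_extra_step_kwargs` feature-detects what the scheduler's `step()` accepts by inspecting its signature, so DDIM-style `eta` and a `generator` are only forwarded when supported. The same introspection pattern on a hypothetical scheduler step function:

import inspect

def fake_scheduler_step(model_output, timestep, sample, eta=0.0, generator=None):
    """Hypothetical stand-in for scheduler.step()."""
    return sample

params = set(inspect.signature(fake_scheduler_step).parameters)
extra_step_kwargs = {}
if "eta" in params:
    extra_step_kwargs["eta"] = 0.0
if "generator" in params:
    extra_step_kwargs["generator"] = None

print(extra_step_kwargs)   # {'eta': 0.0, 'generator': None}
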
@ -761,14 +859,17 @@ class StableDiffusionPowerPaintBrushNetPipeline(
control_guidance_end=1.0,
callback_on_step_end_tensor_inputs=None,
):
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
if callback_steps is not None and (
not isinstance(callback_steps, int) or callback_steps <= 0
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
k in self._callback_tensor_inputs
for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
@ -783,8 +884,12 @@ class StableDiffusionPowerPaintBrushNetPipeline(
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt is not None and (
not isinstance(prompt, str) and not isinstance(prompt, list)
):
raise ValueError(
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
@ -820,7 +925,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
and isinstance(self.brushnet._orig_mod, BrushNetModel)
):
if not isinstance(brushnet_conditioning_scale, float):
raise TypeError("For single brushnet: `brushnet_conditioning_scale` must be type `float`.")
raise TypeError(
"For single brushnet: `brushnet_conditioning_scale` must be type `float`."
)
else:
assert False
@ -841,9 +948,13 @@ class StableDiffusionPowerPaintBrushNetPipeline(
f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
)
if start < 0.0:
raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
raise ValueError(
f"control guidance start: {start} can't be smaller than 0."
)
if end > 1.0:
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
raise ValueError(
f"control guidance end: {end} can't be larger than 1.0."
)
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
raise ValueError(
@ -864,8 +975,12 @@ class StableDiffusionPowerPaintBrushNetPipeline(
image_is_pil = isinstance(image, PIL.Image.Image)
image_is_tensor = isinstance(image, torch.Tensor)
image_is_np = isinstance(image, np.ndarray)
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
image_is_pil_list = isinstance(image, list) and isinstance(
image[0], PIL.Image.Image
)
image_is_tensor_list = isinstance(image, list) and isinstance(
image[0], torch.Tensor
)
image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
if (
@ -883,8 +998,12 @@ class StableDiffusionPowerPaintBrushNetPipeline(
mask_is_pil = isinstance(mask, PIL.Image.Image)
mask_is_tensor = isinstance(mask, torch.Tensor)
mask_is_np = isinstance(mask, np.ndarray)
mask_is_pil_list = isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image)
mask_is_tensor_list = isinstance(mask, list) and isinstance(mask[0], torch.Tensor)
mask_is_pil_list = isinstance(mask, list) and isinstance(
mask[0], PIL.Image.Image
)
mask_is_tensor_list = isinstance(mask, list) and isinstance(
mask[0], torch.Tensor
)
mask_is_np_list = isinstance(mask, list) and isinstance(mask[0], np.ndarray)
if (
@ -928,7 +1047,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
do_classifier_free_guidance=False,
guess_mode=False,
):
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image = self.image_processor.preprocess(image, height=height, width=width).to(
dtype=torch.float32
)
image_batch_size = image.shape[0]
if image_batch_size == 1:
@ -947,8 +1068,23 @@ class StableDiffusionPowerPaintBrushNetPipeline(
return image.to(device=device, dtype=dtype)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
shape = (
batch_size,
num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
@ -1186,14 +1322,26 @@ class StableDiffusionPowerPaintBrushNetPipeline(
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
brushnet = self.brushnet._orig_mod if is_compiled_module(self.brushnet) else self.brushnet
brushnet = (
self.brushnet._orig_mod
if is_compiled_module(self.brushnet)
else self.brushnet
)
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
if not isinstance(control_guidance_start, list) and isinstance(
control_guidance_end, list
):
control_guidance_start = len(control_guidance_end) * [
control_guidance_start
]
elif not isinstance(control_guidance_end, list) and isinstance(
control_guidance_start, list
):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
elif not isinstance(control_guidance_start, list) and not isinstance(
control_guidance_end, list
):
control_guidance_start, control_guidance_end = (
[control_guidance_start],
[control_guidance_end],
@ -1241,7 +1389,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
# 3. Encode input prompt
text_encoder_lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
self.cross_attention_kwargs.get("scale", None)
if self.cross_attention_kwargs is not None
else None
)
prompt_embeds = self._encode_prompt(
@ -1310,7 +1460,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
assert False
# 5. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps
)
self._num_timesteps = len(timesteps)
# 6. Prepare latent variables
@ -1330,14 +1482,15 @@ class StableDiffusionPowerPaintBrushNetPipeline(
# mask_i = transforms.ToPILImage()(image[0:1,:,:,:].squeeze(0))
# mask_i.save('_mask.png')
# print(brushnet.dtype)
conditioning_latents = self.vae.encode(
image.to(device=device, dtype=brushnet.dtype)).latent_dist.sample() * self.vae.config.scaling_factor
conditioning_latents = (
self.vae.encode(
image.to(device=device, dtype=brushnet.dtype)
).latent_dist.sample()
* self.vae.config.scaling_factor
)
mask = torch.nn.functional.interpolate(
original_mask,
size=(
conditioning_latents.shape[-2],
conditioning_latents.shape[-1]
)
size=(conditioning_latents.shape[-2], conditioning_latents.shape[-1]),
)
conditioning_latents = torch.concat([conditioning_latents, mask], 1)
# image = self.vae.decode(conditioning_latents[:1,:4,:,:] / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
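
The block above assembles BrushNet's 5-channel conditioning: the masked image is VAE-encoded and multiplied by the VAE scaling factor, the mask is resized to the latent resolution, and both are concatenated on the channel axis. A shape-only sketch under the usual SD 1.5 assumptions (512×512 input, 8× VAE downsampling, 4 latent channels, scaling factor 0.18215):

import torch
import torch.nn.functional as F

latent = torch.randn(1, 4, 64, 64) * 0.18215        # pretend vae.encode(...).sample() * scaling_factor
mask = torch.rand(1, 1, 512, 512)                    # original-resolution mask
mask = F.interpolate(mask, size=latent.shape[-2:])   # -> [1, 1, 64, 64]
conditioning_latents = torch.cat([latent, mask], dim=1)  # [1, 5, 64, 64], matches conditioning_channels=5
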
@ -1348,7 +1501,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
# 6.5 Optionally get Guidance Scale Embedding
timestep_cond = None
if self.unet.config.time_cond_proj_dim is not None:
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
batch_size * num_images_per_prompt
)
timestep_cond = self.get_guidance_scale_embedding(
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
).to(device=device, dtype=latents.dtype)
@ -1370,7 +1525,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
brushnet_keep.append(keeps[0] if isinstance(brushnet, BrushNetModel) else keeps)
brushnet_keep.append(
keeps[0] if isinstance(brushnet, BrushNetModel) else keeps
)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@ -1381,31 +1538,45 @@ class StableDiffusionPowerPaintBrushNetPipeline(
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_brushnet_compiled) and is_torch_higher_equal_2_1:
if (
is_unet_compiled and is_brushnet_compiled
) and is_torch_higher_equal_2_1:
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = (
torch.cat([latents] * 2)
if self.do_classifier_free_guidance
else latents
)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
# brushnet(s) inference
if guess_mode and self.do_classifier_free_guidance:
# Infer BrushNet only for the conditional batch.
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
control_model_input = self.scheduler.scale_model_input(
control_model_input, t
)
brushnet_prompt_embeds = prompt_embeds.chunk(2)[1]
else:
control_model_input = latent_model_input
brushnet_prompt_embeds = prompt_embeds
if isinstance(brushnet_keep[i], list):
cond_scale = [c * s for c, s in zip(brushnet_conditioning_scale, brushnet_keep[i])]
cond_scale = [
c * s
for c, s in zip(brushnet_conditioning_scale, brushnet_keep[i])
]
else:
brushnet_cond_scale = brushnet_conditioning_scale
if isinstance(brushnet_cond_scale, list):
brushnet_cond_scale = brushnet_cond_scale[0]
cond_scale = brushnet_cond_scale * brushnet_keep[i]
down_block_res_samples, mid_block_res_sample, up_block_res_samples = self.brushnet(
down_block_res_samples, mid_block_res_sample, up_block_res_samples = (
self.brushnet(
control_model_input,
t,
encoder_hidden_states=brushnet_prompt_embeds,
@ -1414,14 +1585,23 @@ class StableDiffusionPowerPaintBrushNetPipeline(
guess_mode=guess_mode,
return_dict=False,
)
)
if guess_mode and self.do_classifier_free_guidance:
# Inferred BrushNet only for the conditional batch.
# To apply the output of BrushNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
up_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in up_block_res_samples]
down_block_res_samples = [
torch.cat([torch.zeros_like(d), d])
for d in down_block_res_samples
]
mid_block_res_sample = torch.cat(
[torch.zeros_like(mid_block_res_sample), mid_block_res_sample]
)
up_block_res_samples = [
torch.cat([torch.zeros_like(d), d])
for d in up_block_res_samples
]
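
The zero-padding above is what lets guess mode reuse a single UNet batch: BrushNet ran only on the conditional half, so each residual is prefixed with zeros for the unconditional half, leaving that pass unchanged. A tiny sketch of the shape manipulation:

import torch

cond_residual = torch.randn(1, 320, 64, 64)   # BrushNet output for the conditional batch
both = torch.cat([torch.zeros_like(cond_residual), cond_residual])  # [2, 320, 64, 64]
# Row 0 (unconditional) adds nothing to the UNet; row 1 carries the BrushNet signal.
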
# predict the noise residual
noise_pred = self.unet(
@ -1440,10 +1620,14 @@ class StableDiffusionPowerPaintBrushNetPipeline(
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = noise_pred_uncond + self.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
@ -1453,10 +1637,14 @@ class StableDiffusionPowerPaintBrushNetPipeline(
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
negative_prompt_embeds = callback_outputs.pop(
"negative_prompt_embeds", negative_prompt_embeds
)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
@ -1470,10 +1658,14 @@ class StableDiffusionPowerPaintBrushNetPipeline(
torch.cuda.empty_cache()
if not output_type == "latent":
image = self.vae.decode(
latents / self.vae.config.scaling_factor,
return_dict=False,
generator=generator,
)[0]
image, has_nsfw_concept = self.run_safety_checker(
image, device, prompt_embeds.dtype
)
else:
image = latents
has_nsfw_concept = None
@ -1483,7 +1675,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(
image, output_type=output_type, do_denormalize=do_denormalize
)
# Offload all models
self.maybe_free_model_hooks()
@ -1491,4 +1685,6 @@ class StableDiffusionPowerPaintBrushNetPipeline(
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(
images=image, nsfw_content_detected=has_nsfw_concept
)
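The loop above relies on two standard steps: in guess mode the BrushNet residuals are computed for the conditional batch only and zero-padded for the unconditional half, and the two noise predictions are then merged by classifier-free guidance. Below is a minimal standalone sketch of both steps; it is not part of the commit, and the shapes and guidance scale are assumed purely for illustration (only torch is needed).
import torch

guidance_scale = 7.5  # assumed value for illustration

# BrushNet residual computed for the conditional batch only (assumed shape)
down_block_res_samples = [torch.randn(1, 320, 64, 64)]
# Zero-pad the unconditional half so the residual matches the CFG batch of 2
down_block_res_samples = [
    torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples
]
assert down_block_res_samples[0].shape[0] == 2

# Classifier-free guidance: combine unconditional and text-conditioned predictions
noise_pred = torch.randn(2, 4, 64, 64)  # stacked [uncond, text] prediction (assumed shape)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)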

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -122,9 +122,13 @@ class ModelInfo(BaseModel):
@computed_field
@property
def support_powerpaint_v2(self) -> bool:
return (
self.model_type
in [
ModelType.DIFFUSERS_SD,
]
and self.name != POWERPAINT_NAME
)
class Choices(str, Enum):
@ -215,7 +219,6 @@ class SDSampler(str, Enum):
lcm = "LCM"
class PowerPaintTask(Choices):
text_guided = "text-guided"
context_aware = "context-aware"
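For quick reference, here is a self-contained sketch of how the support_powerpaint_v2 rule shown earlier in this file behaves in isolation: only diffusers SD models other than PowerPaint itself report support. The enum values and the POWERPAINT_NAME constant below are placeholders for illustration, not the project's real definitions.
from enum import Enum
from pydantic import BaseModel, computed_field

POWERPAINT_NAME = "powerpaint-placeholder-name"  # placeholder, not the real constant

class ModelType(str, Enum):
    DIFFUSERS_SD = "diffusers_sd"      # placeholder value
    DIFFUSERS_SDXL = "diffusers_sdxl"  # placeholder value

class ModelInfo(BaseModel):
    name: str
    model_type: ModelType

    @computed_field
    @property
    def support_powerpaint_v2(self) -> bool:
        return self.model_type in [ModelType.DIFFUSERS_SD] and self.name != POWERPAINT_NAME

print(ModelInfo(name="runwayml/stable-diffusion-v1-5", model_type=ModelType.DIFFUSERS_SD).support_powerpaint_v2)  # True
print(ModelInfo(name=POWERPAINT_NAME, model_type=ModelType.DIFFUSERS_SD).support_powerpaint_v2)  # False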

View File

@ -59,6 +59,9 @@ const DiffusionOptions = () => {
updateExtenderDirection,
adjustMask,
clearMask,
updateEnablePowerPaintV2,
updateEnableBrushNet,
updateEnableControlnet,
] = useStore((state) => [
state.serverConfig.samplers,
state.settings,
@ -71,6 +74,9 @@ const DiffusionOptions = () => {
state.updateExtenderDirection,
state.adjustMask,
state.clearMask,
state.updateEnablePowerPaintV2,
state.updateEnableBrushNet,
state.updateEnableControlnet,
])
const [exampleImage, isExampleImageLoaded] = useImage(paintByExampleFile)
const negativePromptRef = useRef(null)
@ -114,12 +120,8 @@ const DiffusionOptions = () => {
return null
}
let toolTip =
"BrushNet is a plug-and-play image inpainting model that works on any SD1.5 base model."
return (
<div className="flex flex-col gap-4">
@ -129,20 +131,19 @@ const DiffusionOptions = () => {
text="BrushNet"
url="https://github.com/TencentARC/BrushNet"
toolTip={toolTip}
/>
<Switch
id="brushnet"
checked={settings.enableBrushNet}
onCheckedChange={(value) => {
updateEnableBrushNet(value)
}}
/>
</RowContainer>
{/* <RowContainer>
<Slider
defaultValue={[100]}
className="w-[180px]"
min={1}
max={100}
step={1}
@ -155,14 +156,13 @@ const DiffusionOptions = () => {
<NumberInput
id="brushnet-weight"
className="w-[60px] rounded-full"
disabled={!settings.enableBrushNet || disable}
numberValue={settings.brushnetConditioningScale}
allowFloat={false}
onNumberValueChange={(val) => {
updateSettings({ brushnetConditioningScale: val })
}}
/>
</RowContainer> */}
<RowContainer>
<Select
@ -198,12 +198,8 @@ const DiffusionOptions = () => {
return null
}
let toolTip =
"Using an additional conditioning image to control how an image is generated"
return (
<div className="flex flex-col gap-4">
@ -213,15 +209,13 @@ const DiffusionOptions = () => {
text="ControlNet"
url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inpaint#controlnet"
toolTip={toolTip}
/>
<Switch
id="controlnet"
checked={settings.enableControlnet}
onCheckedChange={(value) => {
updateEnableControlnet(value)
}}
/>
</RowContainer>
@ -233,7 +227,7 @@ const DiffusionOptions = () => {
min={1}
max={100}
step={1}
disabled={!settings.enableControlnet}
value={[Math.floor(settings.controlnetConditioningScale * 100)]}
onValueChange={(vals) =>
updateSettings({ controlnetConditioningScale: vals[0] / 100 })
@ -242,7 +236,7 @@ const DiffusionOptions = () => {
<NumberInput
id="controlnet-weight"
className="w-[60px] rounded-full"
disabled={!settings.enableControlnet}
numberValue={settings.controlnetConditioningScale}
allowFloat={false}
onNumberValueChange={(val) => {
@ -286,12 +280,8 @@ const DiffusionOptions = () => {
return null
}
let toolTip =
"Enable quality image generation in typically 2-8 steps. Suggest disabling guidance_scale by setting it to 0. You can also try values between 1.0 and 2.0. When LCM Lora is enabled, LCMSampler will be used automatically."
return (
<>
@ -300,7 +290,6 @@ const DiffusionOptions = () => {
text="LCM LoRA"
url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inference_with_lcm_lora"
toolTip={toolTip}
/>
<Switch
id="lcm-lora"
@ -308,7 +297,6 @@ const DiffusionOptions = () => {
onCheckedChange={(value) => {
updateSettings({ enableLCMLora: value })
}}
/>
</RowContainer>
<Separator />
@ -561,10 +549,6 @@ const DiffusionOptions = () => {
}
const renderPowerPaintTaskType = () => {
return (
<RowContainer>
<LabelTitle
@ -579,7 +563,7 @@ const DiffusionOptions = () => {
}}
disabled={settings.showExtender}
>
<SelectTrigger className="w-[130px]">
<SelectValue placeholder="Select task" />
</SelectTrigger>
<SelectContent align="end">
@ -587,6 +571,7 @@ const DiffusionOptions = () => {
{[
PowerPaintTask.text_guided,
PowerPaintTask.object_remove,
PowerPaintTask.context_aware,
PowerPaintTask.shape_guided,
].map((task) => (
<SelectItem key={task} value={task}>
@ -600,6 +585,44 @@ const DiffusionOptions = () => {
)
}
const renderPowerPaintV1 = () => {
if (settings.model.name !== POWERPAINT) {
return null
}
return (
<>
{renderPowerPaintTaskType()}
<Separator />
</>
)
}
const renderPowerPaintV2 = () => {
if (settings.model.support_powerpaint_v2 === false) {
return null
}
return (
<>
<RowContainer>
<LabelTitle
text="PowerPaint V2"
toolTip="PowerPaint is a plug-and-play image inpainting model that works on any SD1.5 base model."
/>
<Switch
id="powerpaint-v2"
checked={settings.enablePowerPaintV2}
onCheckedChange={(value) => {
updateEnablePowerPaintV2(value)
}}
/>
</RowContainer>
{renderPowerPaintTaskType()}
<Separator />
</>
)
}
const renderSteps = () => {
return (
<RowContainer>
@ -868,7 +891,7 @@ const DiffusionOptions = () => {
{renderMaskBlur()}
{renderMaskAdjuster()}
{renderMatchHistograms()}
{renderPowerPaintV1()}
{renderSteps()}
{renderGuidanceScale()}
{renderP2PImageGuidanceScale()}
@ -878,6 +901,7 @@ const DiffusionOptions = () => {
{renderNegativePrompt()}
<Separator />
{renderBrushNetSetting()}
{renderPowerPaintV2()}
{renderConterNetSetting()}
{renderLCMLora()}
{renderPaintByExample()}

View File

@ -22,7 +22,7 @@ const SelectTrigger = React.forwardRef<
<SelectPrimitive.Trigger
ref={ref}
className={cn(
"flex h-9 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent pl-2 pr-1 py-2 text-sm shadow-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-1 focus:ring-ring disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1",
className
)}
tabIndex={-1}

View File

@ -79,6 +79,7 @@ export default async function inpaint(
enable_brushnet: settings.enableBrushNet,
brushnet_method: settings.brushnetMethod ? settings.brushnetMethod : "",
brushnet_conditioning_scale: settings.brushnetConditioningScale,
enable_powerpaint_v2: settings.enablePowerPaintV2,
powerpaint_task: settings.showExtender
? PowerPaintTask.outpainting
: settings.powerpaintTask,

View File

@ -106,6 +106,7 @@ export type Settings = {
enableLCMLora: boolean
// PowerPaint
enablePowerPaintV2: boolean
powerpaintTask: PowerPaintTask
// AdjustMask
@ -194,6 +195,12 @@ type AppAction = {
setServerConfig: (newValue: ServerConfig) => void
setSeed: (newValue: number) => void
updateSettings: (newSettings: Partial<Settings>) => void
// Mutually exclusive toggles
updateEnablePowerPaintV2: (newValue: boolean) => void
updateEnableBrushNet: (newValue: boolean) => void
updateEnableControlnet: (newValue: boolean) => void
setModel: (newModel: ModelInfo) => void
updateFileManagerState: (newState: Partial<FileManagerState>) => void
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
@ -311,6 +318,7 @@ const defaultValues: AppState = {
support_brushnet: false,
support_strength: false,
support_outpainting: false,
support_powerpaint_v2: false,
controlnets: [],
brushnets: [],
support_lcm_lora: false,
@ -425,6 +433,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
if (
get().settings.model.support_outpainting &&
settings.showExtender &&
extenderState.x === 0 &&
extenderState.y === 0 &&
extenderState.height === imageHeight &&
extenderState.width === imageWidth
) {
@ -798,6 +808,38 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
})
},
updateEnablePowerPaintV2: (newValue: boolean) => {
get().updateSettings({ enablePowerPaintV2: newValue })
if (newValue) {
get().updateSettings({
enableBrushNet: false,
enableControlnet: false,
enableLCMLora: false,
})
}
},
updateEnableBrushNet: (newValue: boolean) => {
get().updateSettings({ enableBrushNet: newValue })
if (newValue) {
get().updateSettings({
enablePowerPaintV2: false,
enableControlnet: false,
enableLCMLora: false,
})
}
},
updateEnableControlnet: (newValue: boolean) => {
get().updateSettings({ enableControlnet: newValue })
if (newValue) {
get().updateSettings({
enablePowerPaintV2: false,
enableBrushNet: false,
})
}
},
setModel: (newModel: ModelInfo) => {
set((state) => {
state.settings.model = newModel

View File

@ -49,6 +49,7 @@ export interface ModelInfo {
support_outpainting: boolean
support_controlnet: boolean
support_brushnet: boolean
support_powerpaint_v2: boolean
controlnets: string[]
brushnets: string[]
support_lcm_lora: boolean
@ -123,6 +124,7 @@ export enum ExtenderDirection {
export enum PowerPaintTask {
text_guided = "text-guided",
shape_guided = "shape-guided",
context_aware = "context-aware",
object_remove = "object-remove",
outpainting = "outpainting",
}