update
This commit is contained in:
parent
017a3d68fd
commit
80ee1b9941
@ -1,6 +1,9 @@
|
||||
from itertools import chain
|
||||
|
||||
import PIL.Image
|
||||
import cv2
|
||||
import torch
|
||||
from iopaint.model.original_sd_configs import get_config_files
|
||||
from loguru import logger
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
import numpy as np
|
||||
@ -14,9 +17,15 @@ from ..utils import (
|
||||
handle_from_pretrained_exceptions,
|
||||
)
|
||||
from .powerpaint_tokenizer import task_to_prompt
|
||||
from iopaint.schema import InpaintRequest
|
||||
from iopaint.schema import InpaintRequest, ModelType
|
||||
from .v2.BrushNet_CA import BrushNetModel
|
||||
from .v2.unet_2d_condition import UNet2DConditionModel
|
||||
from .v2.unet_2d_condition import UNet2DConditionModel_forward
|
||||
from .v2.unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D_forward,
|
||||
DownBlock2D_forward,
|
||||
CrossAttnUpBlock2D_forward,
|
||||
UpBlock2D_forward,
|
||||
)
|
||||
|
||||
|
||||
class PowerPaintV2(DiffusionInpaintModel):
|
||||
@ -50,14 +59,7 @@ class PowerPaintV2(DiffusionInpaintModel):
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=model_kwargs["local_files_only"],
|
||||
)
|
||||
unet = handle_from_pretrained_exceptions(
|
||||
UNet2DConditionModel.from_pretrained,
|
||||
pretrained_model_name_or_path=self.model_id_or_path,
|
||||
subfolder="unet",
|
||||
variant="fp16",
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=model_kwargs["local_files_only"],
|
||||
)
|
||||
|
||||
brushnet = BrushNetModel.from_pretrained(
|
||||
self.hf_model_id,
|
||||
subfolder="PowerPaint_Brushnet",
|
||||
@ -65,16 +67,32 @@ class PowerPaintV2(DiffusionInpaintModel):
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=model_kwargs["local_files_only"],
|
||||
)
|
||||
pipe = handle_from_pretrained_exceptions(
|
||||
StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
|
||||
pretrained_model_name_or_path=self.model_id_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
unet=unet,
|
||||
brushnet=brushnet,
|
||||
text_encoder_brushnet=text_encoder_brushnet,
|
||||
variant="fp16",
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
if self.model_info.is_single_file_diffusers:
|
||||
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
||||
model_kwargs["num_in_channels"] = 4
|
||||
else:
|
||||
model_kwargs["num_in_channels"] = 9
|
||||
|
||||
pipe = StableDiffusionPowerPaintBrushNetPipeline.from_single_file(
|
||||
self.model_id_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
load_safety_checker=False,
|
||||
original_config_file=get_config_files()["v1"],
|
||||
brushnet=brushnet,
|
||||
text_encoder_brushnet=text_encoder_brushnet,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
pipe = handle_from_pretrained_exceptions(
|
||||
StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
|
||||
pretrained_model_name_or_path=self.model_id_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
brushnet=brushnet,
|
||||
text_encoder_brushnet=text_encoder_brushnet,
|
||||
variant="fp16",
|
||||
**model_kwargs,
|
||||
)
|
||||
pipe.tokenizer = PowerPaintTokenizer(
|
||||
CLIPTokenizer.from_pretrained(self.hf_model_id, subfolder="tokenizer")
|
||||
)
|
||||
@ -95,6 +113,34 @@ class PowerPaintV2(DiffusionInpaintModel):
|
||||
|
||||
self.callback = kwargs.pop("callback", None)
|
||||
|
||||
# Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
|
||||
self.model.unet.forward = UNet2DConditionModel_forward.__get__(
|
||||
self.model.unet, self.model.unet.__class__
|
||||
)
|
||||
|
||||
# Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
|
||||
for down_block in chain(
|
||||
self.model.unet.down_blocks, self.model.brushnet.down_blocks
|
||||
):
|
||||
if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
|
||||
down_block.forward = CrossAttnDownBlock2D_forward.__get__(
|
||||
down_block, down_block.__class__
|
||||
)
|
||||
else:
|
||||
down_block.forward = DownBlock2D_forward.__get__(
|
||||
down_block, down_block.__class__
|
||||
)
|
||||
|
||||
for up_block in chain(self.model.unet.up_blocks, self.model.brushnet.up_blocks):
|
||||
if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
|
||||
up_block.forward = CrossAttnUpBlock2D_forward.__get__(
|
||||
up_block, up_block.__class__
|
||||
)
|
||||
else:
|
||||
up_block.forward = UpBlock2D_forward.__get__(
|
||||
up_block, up_block.__class__
|
||||
)
|
||||
|
||||
def forward(self, image, mask, config: InpaintRequest):
|
||||
"""Input image and output image have same size
|
||||
image: [H, W, C] RGB
|
||||
@ -129,11 +175,10 @@ class PowerPaintV2(DiffusionInpaintModel):
|
||||
brushnet_conditioning_scale=1.0,
|
||||
guidance_scale=config.sd_guidance_scale,
|
||||
output_type="np",
|
||||
callback=self.callback,
|
||||
callback_on_step_end=self.callback,
|
||||
height=img_h,
|
||||
width=img_w,
|
||||
generator=torch.manual_seed(config.sd_seed),
|
||||
callback_steps=1,
|
||||
).images[0]
|
||||
|
||||
output = (output * 255).round().astype("uint8")
|
||||
|
@ -2,6 +2,14 @@ from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.models.unet_2d_blocks import (
|
||||
get_down_block,
|
||||
get_mid_block,
|
||||
get_up_block,
|
||||
CrossAttnDownBlock2D,
|
||||
DownBlock2D,
|
||||
)
|
||||
from torch import nn
|
||||
|
||||
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||
@ -13,18 +21,14 @@ from diffusers.models.attention_processor import (
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, \
|
||||
TimestepEmbedding, Timesteps
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from .unet_2d_blocks import (
|
||||
CrossAttnDownBlock2D,
|
||||
DownBlock2D,
|
||||
get_down_block,
|
||||
get_mid_block,
|
||||
get_up_block
|
||||
from diffusers.models.embeddings import (
|
||||
TextImageProjection,
|
||||
TextImageTimeEmbedding,
|
||||
TextTimeEmbedding,
|
||||
TimestepEmbedding,
|
||||
Timesteps,
|
||||
)
|
||||
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@ -132,47 +136,55 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 5,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
brushnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 5,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
brushnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (
|
||||
16,
|
||||
32,
|
||||
96,
|
||||
256,
|
||||
),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -195,25 +207,33 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
||||
if not isinstance(only_cross_attention, bool) and len(
|
||||
only_cross_attention
|
||||
) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
||||
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
|
||||
down_block_types
|
||||
):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(
|
||||
down_block_types
|
||||
)
|
||||
|
||||
# input
|
||||
conv_in_kernel = 3
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
self.conv_in_condition = nn.Conv2d(
|
||||
in_channels + conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel,
|
||||
padding=conv_in_padding
|
||||
in_channels + conditioning_channels,
|
||||
block_out_channels[0],
|
||||
kernel_size=conv_in_kernel,
|
||||
padding=conv_in_padding,
|
||||
)
|
||||
|
||||
# time
|
||||
@ -229,7 +249,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
||||
encoder_hid_dim_type = "text_proj"
|
||||
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
||||
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
||||
logger.info(
|
||||
"encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
|
||||
)
|
||||
|
||||
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
@ -274,7 +296,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
||||
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
||||
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
||||
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||
self.class_embedding = TimestepEmbedding(
|
||||
projection_class_embeddings_input_dim, time_embed_dim
|
||||
)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
@ -285,21 +309,31 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
text_time_embedding_from_dim = cross_attention_dim
|
||||
|
||||
self.add_embedding = TextTimeEmbedding(
|
||||
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
||||
text_time_embedding_from_dim,
|
||||
time_embed_dim,
|
||||
num_heads=addition_embed_type_num_heads,
|
||||
)
|
||||
elif addition_embed_type == "text_image":
|
||||
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
||||
self.add_embedding = TextImageTimeEmbedding(
|
||||
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
||||
text_embed_dim=cross_attention_dim,
|
||||
image_embed_dim=cross_attention_dim,
|
||||
time_embed_dim=time_embed_dim,
|
||||
)
|
||||
elif addition_embed_type == "text_time":
|
||||
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
||||
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||
self.add_time_proj = Timesteps(
|
||||
addition_time_embed_dim, flip_sin_to_cos, freq_shift
|
||||
)
|
||||
self.add_embedding = TimestepEmbedding(
|
||||
projection_class_embeddings_input_dim, time_embed_dim
|
||||
)
|
||||
|
||||
elif addition_embed_type is not None:
|
||||
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
||||
raise ValueError(
|
||||
f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.brushnet_down_blocks = nn.ModuleList([])
|
||||
@ -338,7 +372,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads[i],
|
||||
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||
attention_head_dim=attention_head_dim[i]
|
||||
if attention_head_dim[i] is not None
|
||||
else output_channel,
|
||||
downsample_padding=downsample_padding,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
@ -348,12 +384,16 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
for _ in range(layers_per_block):
|
||||
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
brushnet_block = nn.Conv2d(
|
||||
output_channel, output_channel, kernel_size=1
|
||||
)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_down_blocks.append(brushnet_block)
|
||||
|
||||
if not is_final_block:
|
||||
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
brushnet_block = nn.Conv2d(
|
||||
output_channel, output_channel, kernel_size=1
|
||||
)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_down_blocks.append(brushnet_block)
|
||||
|
||||
@ -386,7 +426,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
||||
reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
|
||||
reversed_transformer_layers_per_block = list(
|
||||
reversed(transformer_layers_per_block)
|
||||
)
|
||||
only_cross_attention = list(reversed(only_cross_attention))
|
||||
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
@ -399,7 +441,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||
input_channel = reversed_block_out_channels[
|
||||
min(i + 1, len(block_out_channels) - 1)
|
||||
]
|
||||
|
||||
# add upsample block for all BUT final layer
|
||||
if not is_final_block:
|
||||
@ -427,29 +471,40 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||
attention_head_dim=attention_head_dim[i]
|
||||
if attention_head_dim[i] is not None
|
||||
else output_channel,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
for _ in range(layers_per_block + 1):
|
||||
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
brushnet_block = nn.Conv2d(
|
||||
output_channel, output_channel, kernel_size=1
|
||||
)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_up_blocks.append(brushnet_block)
|
||||
|
||||
if not is_final_block:
|
||||
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
brushnet_block = nn.Conv2d(
|
||||
output_channel, output_channel, kernel_size=1
|
||||
)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_up_blocks.append(brushnet_block)
|
||||
|
||||
@classmethod
|
||||
def from_unet(
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
brushnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
load_weights_from_unet: bool = True,
|
||||
conditioning_channels: int = 5,
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
brushnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (
|
||||
16,
|
||||
32,
|
||||
96,
|
||||
256,
|
||||
),
|
||||
load_weights_from_unet: bool = True,
|
||||
conditioning_channels: int = 5,
|
||||
):
|
||||
r"""
|
||||
Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
|
||||
@ -460,13 +515,27 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
where applicable.
|
||||
"""
|
||||
transformer_layers_per_block = (
|
||||
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
||||
unet.config.transformer_layers_per_block
|
||||
if "transformer_layers_per_block" in unet.config
|
||||
else 1
|
||||
)
|
||||
encoder_hid_dim = (
|
||||
unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
||||
)
|
||||
encoder_hid_dim_type = (
|
||||
unet.config.encoder_hid_dim_type
|
||||
if "encoder_hid_dim_type" in unet.config
|
||||
else None
|
||||
)
|
||||
addition_embed_type = (
|
||||
unet.config.addition_embed_type
|
||||
if "addition_embed_type" in unet.config
|
||||
else None
|
||||
)
|
||||
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
||||
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
||||
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
||||
addition_time_embed_dim = (
|
||||
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
||||
unet.config.addition_time_embed_dim
|
||||
if "addition_time_embed_dim" in unet.config
|
||||
else None
|
||||
)
|
||||
|
||||
brushnet = cls(
|
||||
@ -475,14 +544,21 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
||||
freq_shift=unet.config.freq_shift,
|
||||
# down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
|
||||
down_block_types=["CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D", ],
|
||||
down_block_types=[
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
],
|
||||
# mid_block_type='MidBlock2D',
|
||||
mid_block_type="UNetMidBlock2DCrossAttn",
|
||||
# up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
|
||||
up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
|
||||
up_block_types=[
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
],
|
||||
only_cross_attention=unet.config.only_cross_attention,
|
||||
block_out_channels=unet.config.block_out_channels,
|
||||
layers_per_block=unet.config.layers_per_block,
|
||||
@ -510,21 +586,33 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
if load_weights_from_unet:
|
||||
conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
|
||||
conv_in_condition_weight = torch.zeros_like(
|
||||
brushnet.conv_in_condition.weight
|
||||
)
|
||||
conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
|
||||
conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
|
||||
brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
|
||||
brushnet.conv_in_condition.weight = torch.nn.Parameter(
|
||||
conv_in_condition_weight
|
||||
)
|
||||
brushnet.conv_in_condition.bias = unet.conv_in.bias
|
||||
|
||||
brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
||||
brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
||||
|
||||
if brushnet.class_embedding:
|
||||
brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
||||
brushnet.class_embedding.load_state_dict(
|
||||
unet.class_embedding.state_dict()
|
||||
)
|
||||
|
||||
brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
|
||||
brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
|
||||
brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
|
||||
brushnet.down_blocks.load_state_dict(
|
||||
unet.down_blocks.state_dict(), strict=False
|
||||
)
|
||||
brushnet.mid_block.load_state_dict(
|
||||
unet.mid_block.state_dict(), strict=False
|
||||
)
|
||||
brushnet.up_blocks.load_state_dict(
|
||||
unet.up_blocks.state_dict(), strict=False
|
||||
)
|
||||
|
||||
return brushnet.to(unet.dtype)
|
||||
|
||||
@ -539,9 +627,15 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(
|
||||
name: str,
|
||||
module: torch.nn.Module,
|
||||
processors: Dict[str, AttentionProcessor],
|
||||
):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
processors[f"{name}.processor"] = module.get_processor(
|
||||
return_deprecated_lora=True
|
||||
)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
@ -554,7 +648,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@ -593,9 +689,15 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
if all(
|
||||
proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
|
||||
for proc in self.attn_processors.values()
|
||||
):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
elif all(
|
||||
proc.__class__ in CROSS_ATTENTION_PROCESSORS
|
||||
for proc in self.attn_processors.values()
|
||||
):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -642,7 +744,11 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
# make smallest slice possible
|
||||
slice_size = num_sliceable_layers * [1]
|
||||
|
||||
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
||||
slice_size = (
|
||||
num_sliceable_layers * [slice_size]
|
||||
if not isinstance(slice_size, list)
|
||||
else slice_size
|
||||
)
|
||||
|
||||
if len(slice_size) != len(sliceable_head_dims):
|
||||
raise ValueError(
|
||||
@ -659,7 +765,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||
def fn_recursive_set_attention_slice(
|
||||
module: torch.nn.Module, slice_size: List[int]
|
||||
):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
@ -675,19 +783,19 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
brushnet_cond: torch.FloatTensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
brushnet_cond: torch.FloatTensor,
|
||||
conditioning_scale: float = 1.0,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
timestep_cond: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
guess_mode: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
|
||||
"""
|
||||
The [`BrushNetModel`] forward method.
|
||||
@ -737,7 +845,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
elif channel_order == "bgr":
|
||||
brushnet_cond = torch.flip(brushnet_cond, dims=[1])
|
||||
else:
|
||||
raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
|
||||
raise ValueError(
|
||||
f"unknown `brushnet_conditioning_channel_order`: {channel_order}"
|
||||
)
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
@ -773,7 +883,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||
raise ValueError(
|
||||
"class_labels should be provided when num_class_embeds > 0"
|
||||
)
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
@ -812,7 +924,10 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
if (
|
||||
hasattr(downsample_block, "has_cross_attention")
|
||||
and downsample_block.has_cross_attention
|
||||
):
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
@ -827,13 +942,20 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# 4. PaintingNet down blocks
|
||||
brushnet_down_block_res_samples = ()
|
||||
for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
|
||||
for down_block_res_sample, brushnet_down_block in zip(
|
||||
down_block_res_samples, self.brushnet_down_blocks
|
||||
):
|
||||
down_block_res_sample = brushnet_down_block(down_block_res_sample)
|
||||
brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
|
||||
brushnet_down_block_res_samples = brushnet_down_block_res_samples + (
|
||||
down_block_res_sample,
|
||||
)
|
||||
|
||||
# 5. mid
|
||||
if self.mid_block is not None:
|
||||
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||
if (
|
||||
hasattr(self.mid_block, "has_cross_attention")
|
||||
and self.mid_block.has_cross_attention
|
||||
):
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
@ -852,15 +974,20 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
is_final_block = i == len(self.up_blocks) - 1
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[
|
||||
: -len(upsample_block.resnets)
|
||||
]
|
||||
|
||||
# if we have not reached the final block and need to forward the
|
||||
# upsample size, we do it here
|
||||
if not is_final_block:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
|
||||
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
||||
if (
|
||||
hasattr(upsample_block, "has_cross_attention")
|
||||
and upsample_block.has_cross_attention
|
||||
):
|
||||
sample, up_res_samples = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
@ -869,7 +996,7 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
upsample_size=upsample_size,
|
||||
attention_mask=attention_mask,
|
||||
return_res_samples=True
|
||||
return_res_samples=True,
|
||||
)
|
||||
else:
|
||||
sample, up_res_samples = upsample_block(
|
||||
@ -877,53 +1004,87 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
upsample_size=upsample_size,
|
||||
return_res_samples=True
|
||||
return_res_samples=True,
|
||||
)
|
||||
|
||||
up_block_res_samples += up_res_samples
|
||||
|
||||
# 8. BrushNet up blocks
|
||||
brushnet_up_block_res_samples = ()
|
||||
for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
|
||||
for up_block_res_sample, brushnet_up_block in zip(
|
||||
up_block_res_samples, self.brushnet_up_blocks
|
||||
):
|
||||
up_block_res_sample = brushnet_up_block(up_block_res_sample)
|
||||
brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
|
||||
brushnet_up_block_res_samples = brushnet_up_block_res_samples + (
|
||||
up_block_res_sample,
|
||||
)
|
||||
|
||||
# 6. scaling
|
||||
if guess_mode and not self.config.global_pool_conditions:
|
||||
scales = torch.logspace(-1, 0,
|
||||
len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
|
||||
device=sample.device) # 0.1 to 1.0
|
||||
scales = torch.logspace(
|
||||
-1,
|
||||
0,
|
||||
len(brushnet_down_block_res_samples)
|
||||
+ 1
|
||||
+ len(brushnet_up_block_res_samples),
|
||||
device=sample.device,
|
||||
) # 0.1 to 1.0
|
||||
scales = scales * conditioning_scale
|
||||
|
||||
brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples,
|
||||
scales[:len(
|
||||
brushnet_down_block_res_samples)])]
|
||||
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
|
||||
brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples,
|
||||
scales[
|
||||
len(brushnet_down_block_res_samples) + 1:])]
|
||||
brushnet_down_block_res_samples = [
|
||||
sample * scale
|
||||
for sample, scale in zip(
|
||||
brushnet_down_block_res_samples,
|
||||
scales[: len(brushnet_down_block_res_samples)],
|
||||
)
|
||||
]
|
||||
brushnet_mid_block_res_sample = (
|
||||
brushnet_mid_block_res_sample
|
||||
* scales[len(brushnet_down_block_res_samples)]
|
||||
)
|
||||
brushnet_up_block_res_samples = [
|
||||
sample * scale
|
||||
for sample, scale in zip(
|
||||
brushnet_up_block_res_samples,
|
||||
scales[len(brushnet_down_block_res_samples) + 1 :],
|
||||
)
|
||||
]
|
||||
else:
|
||||
brushnet_down_block_res_samples = [sample * conditioning_scale for sample in
|
||||
brushnet_down_block_res_samples]
|
||||
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
|
||||
brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
|
||||
brushnet_down_block_res_samples = [
|
||||
sample * conditioning_scale
|
||||
for sample in brushnet_down_block_res_samples
|
||||
]
|
||||
brushnet_mid_block_res_sample = (
|
||||
brushnet_mid_block_res_sample * conditioning_scale
|
||||
)
|
||||
brushnet_up_block_res_samples = [
|
||||
sample * conditioning_scale for sample in brushnet_up_block_res_samples
|
||||
]
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
brushnet_down_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True)
|
||||
for sample in brushnet_down_block_res_samples
|
||||
]
|
||||
brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||
brushnet_mid_block_res_sample = torch.mean(
|
||||
brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True
|
||||
)
|
||||
brushnet_up_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True)
|
||||
for sample in brushnet_up_block_res_samples
|
||||
]
|
||||
|
||||
if not return_dict:
|
||||
return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
|
||||
return (
|
||||
brushnet_down_block_res_samples,
|
||||
brushnet_mid_block_res_sample,
|
||||
brushnet_up_block_res_samples,
|
||||
)
|
||||
|
||||
return BrushNetOutput(
|
||||
down_block_res_samples=brushnet_down_block_res_samples,
|
||||
mid_block_res_sample=brushnet_mid_block_res_sample,
|
||||
up_block_res_samples=brushnet_up_block_res_samples
|
||||
up_block_res_samples=brushnet_up_block_res_samples,
|
||||
)
|
||||
|
||||
|
||||
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -122,9 +122,13 @@ class ModelInfo(BaseModel):
|
||||
@computed_field
|
||||
@property
|
||||
def support_powerpaint_v2(self) -> bool:
|
||||
return self.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
]
|
||||
return (
|
||||
self.model_type
|
||||
in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
]
|
||||
and self.name != POWERPAINT_NAME
|
||||
)
|
||||
|
||||
|
||||
class Choices(str, Enum):
|
||||
@ -215,7 +219,6 @@ class SDSampler(str, Enum):
|
||||
lcm = "LCM"
|
||||
|
||||
|
||||
|
||||
class PowerPaintTask(Choices):
|
||||
text_guided = "text-guided"
|
||||
context_aware = "context-aware"
|
||||
|
@ -59,6 +59,9 @@ const DiffusionOptions = () => {
|
||||
updateExtenderDirection,
|
||||
adjustMask,
|
||||
clearMask,
|
||||
updateEnablePowerPaintV2,
|
||||
updateEnableBrushNet,
|
||||
updateEnableControlnet,
|
||||
] = useStore((state) => [
|
||||
state.serverConfig.samplers,
|
||||
state.settings,
|
||||
@ -71,6 +74,9 @@ const DiffusionOptions = () => {
|
||||
state.updateExtenderDirection,
|
||||
state.adjustMask,
|
||||
state.clearMask,
|
||||
state.updateEnablePowerPaintV2,
|
||||
state.updateEnableBrushNet,
|
||||
state.updateEnableControlnet,
|
||||
])
|
||||
const [exampleImage, isExampleImageLoaded] = useImage(paintByExampleFile)
|
||||
const negativePromptRef = useRef(null)
|
||||
@ -114,12 +120,8 @@ const DiffusionOptions = () => {
|
||||
return null
|
||||
}
|
||||
|
||||
let disable = settings.enableControlnet
|
||||
let toolTip =
|
||||
"BrushNet is a plug-and-play image inpainting model with decomposed dual-branch diffusion. It can be used to inpaint images by conditioning on a mask."
|
||||
if (disable) {
|
||||
toolTip = "ControlNet is enabled, BrushNet is disabled."
|
||||
}
|
||||
"BrushNet is a plug-and-play image inpainting model works on any SD1.5 base models."
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-4">
|
||||
@ -129,20 +131,19 @@ const DiffusionOptions = () => {
|
||||
text="BrushNet"
|
||||
url="https://github.com/TencentARC/BrushNet"
|
||||
toolTip={toolTip}
|
||||
disabled={disable}
|
||||
/>
|
||||
<Switch
|
||||
id="brushnet"
|
||||
checked={settings.enableBrushNet}
|
||||
onCheckedChange={(value) => {
|
||||
updateSettings({ enableBrushNet: value })
|
||||
updateEnableBrushNet(value)
|
||||
}}
|
||||
disabled={disable}
|
||||
/>
|
||||
</RowContainer>
|
||||
<RowContainer>
|
||||
{/* <RowContainer>
|
||||
<Slider
|
||||
defaultValue={[100]}
|
||||
className="w-[180px]"
|
||||
min={1}
|
||||
max={100}
|
||||
step={1}
|
||||
@ -155,14 +156,13 @@ const DiffusionOptions = () => {
|
||||
<NumberInput
|
||||
id="brushnet-weight"
|
||||
className="w-[60px] rounded-full"
|
||||
disabled={!settings.enableBrushNet || disable}
|
||||
numberValue={settings.brushnetConditioningScale}
|
||||
allowFloat={false}
|
||||
onNumberValueChange={(val) => {
|
||||
updateSettings({ brushnetConditioningScale: val })
|
||||
}}
|
||||
/>
|
||||
</RowContainer>
|
||||
</RowContainer> */}
|
||||
|
||||
<RowContainer>
|
||||
<Select
|
||||
@ -198,12 +198,8 @@ const DiffusionOptions = () => {
|
||||
return null
|
||||
}
|
||||
|
||||
let disable = settings.enableBrushNet
|
||||
let toolTip =
|
||||
"Using an additional conditioning image to control how an image is generated"
|
||||
if (disable) {
|
||||
toolTip = "BrushNet is enabled, ControlNet is disabled."
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-4">
|
||||
@ -213,15 +209,13 @@ const DiffusionOptions = () => {
|
||||
text="ControlNet"
|
||||
url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inpaint#controlnet"
|
||||
toolTip={toolTip}
|
||||
disabled={disable}
|
||||
/>
|
||||
<Switch
|
||||
id="controlnet"
|
||||
checked={settings.enableControlnet}
|
||||
onCheckedChange={(value) => {
|
||||
updateSettings({ enableControlnet: value })
|
||||
updateEnableControlnet(value)
|
||||
}}
|
||||
disabled={disable}
|
||||
/>
|
||||
</RowContainer>
|
||||
|
||||
@ -233,7 +227,7 @@ const DiffusionOptions = () => {
|
||||
min={1}
|
||||
max={100}
|
||||
step={1}
|
||||
disabled={!settings.enableControlnet || disable}
|
||||
disabled={!settings.enableControlnet}
|
||||
value={[Math.floor(settings.controlnetConditioningScale * 100)]}
|
||||
onValueChange={(vals) =>
|
||||
updateSettings({ controlnetConditioningScale: vals[0] / 100 })
|
||||
@ -242,7 +236,7 @@ const DiffusionOptions = () => {
|
||||
<NumberInput
|
||||
id="controlnet-weight"
|
||||
className="w-[60px] rounded-full"
|
||||
disabled={!settings.enableControlnet || disable}
|
||||
disabled={!settings.enableControlnet}
|
||||
numberValue={settings.controlnetConditioningScale}
|
||||
allowFloat={false}
|
||||
onNumberValueChange={(val) => {
|
||||
@ -286,12 +280,8 @@ const DiffusionOptions = () => {
|
||||
return null
|
||||
}
|
||||
|
||||
let disable = settings.enableBrushNet
|
||||
let toolTip =
|
||||
"Enable quality image generation in typically 2-4 steps. Suggest disabling guidance_scale by setting it to 0. You can also try values between 1.0 and 2.0. When LCM Lora is enabled, LCMSampler will be used automatically."
|
||||
if (disable) {
|
||||
toolTip = "BrushNet is enabled, LCM Lora is disabled."
|
||||
}
|
||||
"Enable quality image generation in typically 2-8 steps. Suggest disabling guidance_scale by setting it to 0. You can also try values between 1.0 and 2.0. When LCM Lora is enabled, LCMSampler will be used automatically."
|
||||
|
||||
return (
|
||||
<>
|
||||
@ -300,7 +290,6 @@ const DiffusionOptions = () => {
|
||||
text="LCM LoRA"
|
||||
url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inference_with_lcm_lora"
|
||||
toolTip={toolTip}
|
||||
disabled={disable}
|
||||
/>
|
||||
<Switch
|
||||
id="lcm-lora"
|
||||
@ -308,7 +297,6 @@ const DiffusionOptions = () => {
|
||||
onCheckedChange={(value) => {
|
||||
updateSettings({ enableLCMLora: value })
|
||||
}}
|
||||
disabled={disable}
|
||||
/>
|
||||
</RowContainer>
|
||||
<Separator />
|
||||
@ -561,10 +549,6 @@ const DiffusionOptions = () => {
|
||||
}
|
||||
|
||||
const renderPowerPaintTaskType = () => {
|
||||
if (settings.model.name !== POWERPAINT) {
|
||||
return null
|
||||
}
|
||||
|
||||
return (
|
||||
<RowContainer>
|
||||
<LabelTitle
|
||||
@ -579,7 +563,7 @@ const DiffusionOptions = () => {
|
||||
}}
|
||||
disabled={settings.showExtender}
|
||||
>
|
||||
<SelectTrigger className="w-[140px]">
|
||||
<SelectTrigger className="w-[130px]">
|
||||
<SelectValue placeholder="Select task" />
|
||||
</SelectTrigger>
|
||||
<SelectContent align="end">
|
||||
@ -587,6 +571,7 @@ const DiffusionOptions = () => {
|
||||
{[
|
||||
PowerPaintTask.text_guided,
|
||||
PowerPaintTask.object_remove,
|
||||
PowerPaintTask.context_aware,
|
||||
PowerPaintTask.shape_guided,
|
||||
].map((task) => (
|
||||
<SelectItem key={task} value={task}>
|
||||
@ -600,6 +585,44 @@ const DiffusionOptions = () => {
|
||||
)
|
||||
}
|
||||
|
||||
const renderPowerPaintV1 = () => {
|
||||
if (settings.model.name !== POWERPAINT) {
|
||||
return null
|
||||
}
|
||||
return (
|
||||
<>
|
||||
{renderPowerPaintTaskType()}
|
||||
<Separator />
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
const renderPowerPaintV2 = () => {
|
||||
if (settings.model.support_powerpaint_v2 === false) {
|
||||
return null
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<RowContainer>
|
||||
<LabelTitle
|
||||
text="PowerPaint V2"
|
||||
toolTip="PowerPaint is a plug-and-play image inpainting model works on any SD1.5 base models."
|
||||
/>
|
||||
<Switch
|
||||
id="powerpaint-v2"
|
||||
checked={settings.enablePowerPaintV2}
|
||||
onCheckedChange={(value) => {
|
||||
updateEnablePowerPaintV2(value)
|
||||
}}
|
||||
/>
|
||||
</RowContainer>
|
||||
{renderPowerPaintTaskType()}
|
||||
<Separator />
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
const renderSteps = () => {
|
||||
return (
|
||||
<RowContainer>
|
||||
@ -868,7 +891,7 @@ const DiffusionOptions = () => {
|
||||
{renderMaskBlur()}
|
||||
{renderMaskAdjuster()}
|
||||
{renderMatchHistograms()}
|
||||
{renderPowerPaintTaskType()}
|
||||
{renderPowerPaintV1()}
|
||||
{renderSteps()}
|
||||
{renderGuidanceScale()}
|
||||
{renderP2PImageGuidanceScale()}
|
||||
@ -878,6 +901,7 @@ const DiffusionOptions = () => {
|
||||
{renderNegativePrompt()}
|
||||
<Separator />
|
||||
{renderBrushNetSetting()}
|
||||
{renderPowerPaintV2()}
|
||||
{renderConterNetSetting()}
|
||||
{renderLCMLora()}
|
||||
{renderPaintByExample()}
|
||||
|
@ -22,7 +22,7 @@ const SelectTrigger = React.forwardRef<
|
||||
<SelectPrimitive.Trigger
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"flex h-9 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent px-3 py-2 text-sm shadow-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-1 focus:ring-ring disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1",
|
||||
"flex h-9 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent pl-2 pr-1 py-2 text-sm shadow-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-1 focus:ring-ring disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1",
|
||||
className
|
||||
)}
|
||||
tabIndex={-1}
|
||||
|
@ -79,6 +79,7 @@ export default async function inpaint(
|
||||
enable_brushnet: settings.enableBrushNet,
|
||||
brushnet_method: settings.brushnetMethod ? settings.brushnetMethod : "",
|
||||
brushnet_conditioning_scale: settings.brushnetConditioningScale,
|
||||
enable_powerpaint_v2: settings.enablePowerPaintV2,
|
||||
powerpaint_task: settings.showExtender
|
||||
? PowerPaintTask.outpainting
|
||||
: settings.powerpaintTask,
|
||||
|
@ -106,6 +106,7 @@ export type Settings = {
|
||||
enableLCMLora: boolean
|
||||
|
||||
// PowerPaint
|
||||
enablePowerPaintV2: boolean
|
||||
powerpaintTask: PowerPaintTask
|
||||
|
||||
// AdjustMask
|
||||
@ -194,6 +195,12 @@ type AppAction = {
|
||||
setServerConfig: (newValue: ServerConfig) => void
|
||||
setSeed: (newValue: number) => void
|
||||
updateSettings: (newSettings: Partial<Settings>) => void
|
||||
|
||||
// 互斥
|
||||
updateEnablePowerPaintV2: (newValue: boolean) => void
|
||||
updateEnableBrushNet: (newValue: boolean) => void
|
||||
updateEnableControlnet: (newValue: boolean) => void
|
||||
|
||||
setModel: (newModel: ModelInfo) => void
|
||||
updateFileManagerState: (newState: Partial<FileManagerState>) => void
|
||||
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
|
||||
@ -311,6 +318,7 @@ const defaultValues: AppState = {
|
||||
support_brushnet: false,
|
||||
support_strength: false,
|
||||
support_outpainting: false,
|
||||
support_powerpaint_v2: false,
|
||||
controlnets: [],
|
||||
brushnets: [],
|
||||
support_lcm_lora: false,
|
||||
@ -425,6 +433,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
if (
|
||||
get().settings.model.support_outpainting &&
|
||||
settings.showExtender &&
|
||||
extenderState.x === 0 &&
|
||||
extenderState.y === 0 &&
|
||||
extenderState.height === imageHeight &&
|
||||
extenderState.width === imageWidth
|
||||
) {
|
||||
@ -798,6 +808,38 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
||||
})
|
||||
},
|
||||
|
||||
updateEnablePowerPaintV2: (newValue: boolean) => {
|
||||
get().updateSettings({ enablePowerPaintV2: newValue })
|
||||
if (newValue) {
|
||||
get().updateSettings({
|
||||
enableBrushNet: false,
|
||||
enableControlnet: false,
|
||||
enableLCMLora: false,
|
||||
})
|
||||
}
|
||||
},
|
||||
|
||||
updateEnableBrushNet: (newValue: boolean) => {
|
||||
get().updateSettings({ enableBrushNet: newValue })
|
||||
if (newValue) {
|
||||
get().updateSettings({
|
||||
enablePowerPaintV2: false,
|
||||
enableControlnet: false,
|
||||
enableLCMLora: false,
|
||||
})
|
||||
}
|
||||
},
|
||||
|
||||
updateEnableControlnet(newValue) {
|
||||
get().updateSettings({ enableControlnet: newValue })
|
||||
if (newValue) {
|
||||
get().updateSettings({
|
||||
enablePowerPaintV2: false,
|
||||
enableBrushNet: false,
|
||||
})
|
||||
}
|
||||
},
|
||||
|
||||
setModel: (newModel: ModelInfo) => {
|
||||
set((state) => {
|
||||
state.settings.model = newModel
|
||||
|
@ -49,6 +49,7 @@ export interface ModelInfo {
|
||||
support_outpainting: boolean
|
||||
support_controlnet: boolean
|
||||
support_brushnet: boolean
|
||||
support_powerpaint_v2: boolean
|
||||
controlnets: string[]
|
||||
brushnets: string[]
|
||||
support_lcm_lora: boolean
|
||||
@ -123,6 +124,7 @@ export enum ExtenderDirection {
|
||||
export enum PowerPaintTask {
|
||||
text_guided = "text-guided",
|
||||
shape_guided = "shape-guided",
|
||||
context_aware = "context-aware",
|
||||
object_remove = "object-remove",
|
||||
outpainting = "outpainting",
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user