commit 80ee1b9941 (parent 017a3d68fd): update
@@ -1,6 +1,9 @@
+from itertools import chain
+
 import PIL.Image
 import cv2
 import torch
+from iopaint.model.original_sd_configs import get_config_files
 from loguru import logger
 from transformers import CLIPTextModel, CLIPTokenizer
 import numpy as np
@@ -14,9 +17,15 @@ from ..utils import (
     handle_from_pretrained_exceptions,
 )
 from .powerpaint_tokenizer import task_to_prompt
-from iopaint.schema import InpaintRequest
+from iopaint.schema import InpaintRequest, ModelType
 from .v2.BrushNet_CA import BrushNetModel
-from .v2.unet_2d_condition import UNet2DConditionModel
+from .v2.unet_2d_condition import UNet2DConditionModel_forward
+from .v2.unet_2d_blocks import (
+    CrossAttnDownBlock2D_forward,
+    DownBlock2D_forward,
+    CrossAttnUpBlock2D_forward,
+    UpBlock2D_forward,
+)
 
 
 class PowerPaintV2(DiffusionInpaintModel):
@@ -50,14 +59,7 @@ class PowerPaintV2(DiffusionInpaintModel):
             torch_dtype=torch_dtype,
             local_files_only=model_kwargs["local_files_only"],
         )
-        unet = handle_from_pretrained_exceptions(
-            UNet2DConditionModel.from_pretrained,
-            pretrained_model_name_or_path=self.model_id_or_path,
-            subfolder="unet",
-            variant="fp16",
-            torch_dtype=torch_dtype,
-            local_files_only=model_kwargs["local_files_only"],
-        )
         brushnet = BrushNetModel.from_pretrained(
             self.hf_model_id,
             subfolder="PowerPaint_Brushnet",
@@ -65,16 +67,32 @@ class PowerPaintV2(DiffusionInpaintModel):
             torch_dtype=torch_dtype,
             local_files_only=model_kwargs["local_files_only"],
         )
-        pipe = handle_from_pretrained_exceptions(
-            StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
-            pretrained_model_name_or_path=self.model_id_or_path,
-            torch_dtype=torch_dtype,
-            unet=unet,
-            brushnet=brushnet,
-            text_encoder_brushnet=text_encoder_brushnet,
-            variant="fp16",
-            **model_kwargs,
-        )
+        if self.model_info.is_single_file_diffusers:
+            if self.model_info.model_type == ModelType.DIFFUSERS_SD:
+                model_kwargs["num_in_channels"] = 4
+            else:
+                model_kwargs["num_in_channels"] = 9
+
+            pipe = StableDiffusionPowerPaintBrushNetPipeline.from_single_file(
+                self.model_id_or_path,
+                torch_dtype=torch_dtype,
+                load_safety_checker=False,
+                original_config_file=get_config_files()["v1"],
+                brushnet=brushnet,
+                text_encoder_brushnet=text_encoder_brushnet,
+                **model_kwargs,
+            )
+        else:
+            pipe = handle_from_pretrained_exceptions(
+                StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
+                pretrained_model_name_or_path=self.model_id_or_path,
+                torch_dtype=torch_dtype,
+                brushnet=brushnet,
+                text_encoder_brushnet=text_encoder_brushnet,
+                variant="fp16",
+                **model_kwargs,
+            )
         pipe.tokenizer = PowerPaintTokenizer(
             CLIPTokenizer.from_pretrained(self.hf_model_id, subfolder="tokenizer")
         )
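
Note: the new branch loads single-file checkpoints via `from_single_file`, while directory-layout diffusers models keep the `from_pretrained` path. The `num_in_channels` override has to match the checkpoint's UNet `conv_in` width. A minimal sketch of the arithmetic, using the standard SD 1.5 channel counts (an assumption, not taken from this diff):

# Illustrative only: why num_in_channels is 4 for a plain SD model and 9 for
# an SD inpainting model.
def unet_in_channels(is_inpaint_unet: bool) -> int:
    latent = 4  # VAE latent channels (ModelType.DIFFUSERS_SD -> 4)
    if not is_inpaint_unet:
        return latent
    mask = 1  # single-channel inpaint mask
    masked_image_latent = 4  # latent of the masked source image
    return latent + mask + masked_image_latent  # -> 9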
@@ -95,6 +113,34 @@ class PowerPaintV2(DiffusionInpaintModel):
 
         self.callback = kwargs.pop("callback", None)
 
+        # Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
+        self.model.unet.forward = UNet2DConditionModel_forward.__get__(
+            self.model.unet, self.model.unet.__class__
+        )
+
+        # Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
+        for down_block in chain(
+            self.model.unet.down_blocks, self.model.brushnet.down_blocks
+        ):
+            if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
+                down_block.forward = CrossAttnDownBlock2D_forward.__get__(
+                    down_block, down_block.__class__
+                )
+            else:
+                down_block.forward = DownBlock2D_forward.__get__(
+                    down_block, down_block.__class__
+                )
+
+        for up_block in chain(self.model.unet.up_blocks, self.model.brushnet.up_blocks):
+            if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
+                up_block.forward = CrossAttnUpBlock2D_forward.__get__(
+                    up_block, up_block.__class__
+                )
+            else:
+                up_block.forward = UpBlock2D_forward.__get__(
+                    up_block, up_block.__class__
+                )
+
     def forward(self, image, mask, config: InpaintRequest):
         """Input image and output image have same size
         image: [H, W, C] RGB
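
Note: `some_function.__get__(obj, cls)` is the standard descriptor trick for patching a method on a single instance: calling `__get__` on a plain function returns a bound method, so the replacement receives `self` like a normal method while other instances keep the original. A self-contained sketch with illustrative names only:

class Greeter:
    def forward(self):
        return "original"

def patched_forward(self):
    # Bound like a real method, so `self` is available here.
    return f"patched on {type(self).__name__}"

g = Greeter()
g.forward = patched_forward.__get__(g, Greeter)  # bind to this instance only
assert g.forward() == "patched on Greeter"
assert Greeter().forward() == "original"  # other instances are untouched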
@@ -129,11 +175,10 @@ class PowerPaintV2(DiffusionInpaintModel):
             brushnet_conditioning_scale=1.0,
             guidance_scale=config.sd_guidance_scale,
             output_type="np",
-            callback=self.callback,
+            callback_on_step_end=self.callback,
             height=img_h,
             width=img_w,
             generator=torch.manual_seed(config.sd_seed),
-            callback_steps=1,
         ).images[0]
 
         output = (output * 255).round().astype("uint8")
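
Note: this tracks the diffusers callback migration, where the deprecated `callback`/`callback_steps` pair is replaced by `callback_on_step_end`, which receives the pipeline itself and must return the callback kwargs. A sketch of the newer contract, assuming a recent diffusers release:

# Hypothetical callback showing the callback_on_step_end signature.
def on_step_end(pipe, step: int, timestep: int, callback_kwargs: dict) -> dict:
    print(f"step {step}, t={timestep}")
    return callback_kwargs  # must return the (possibly modified) kwargs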
@@ -2,6 +2,14 @@ from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
+from diffusers import UNet2DConditionModel
+from diffusers.models.unet_2d_blocks import (
+    get_down_block,
+    get_mid_block,
+    get_up_block,
+    CrossAttnDownBlock2D,
+    DownBlock2D,
+)
 from torch import nn
 
 from diffusers.configuration_utils import ConfigMixin, register_to_config
@@ -13,18 +21,14 @@ from diffusers.models.attention_processor import (
     AttnAddedKVProcessor,
     AttnProcessor,
 )
-from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, \
-    TimestepEmbedding, Timesteps
-from diffusers.models.modeling_utils import ModelMixin
-from .unet_2d_blocks import (
-    CrossAttnDownBlock2D,
-    DownBlock2D,
-    get_down_block,
-    get_mid_block,
-    get_up_block
-)
-from .unet_2d_condition import UNet2DConditionModel
+from diffusers.models.embeddings import (
+    TextImageProjection,
+    TextImageTimeEmbedding,
+    TextTimeEmbedding,
+    TimestepEmbedding,
+    Timesteps,
+)
+from diffusers.models.modeling_utils import ModelMixin
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -132,47 +136,55 @@ class BrushNetModel(ModelMixin, ConfigMixin):
 
     @register_to_config
     def __init__(
         self,
         in_channels: int = 4,
         conditioning_channels: int = 5,
         flip_sin_to_cos: bool = True,
         freq_shift: int = 0,
         down_block_types: Tuple[str, ...] = (
             "CrossAttnDownBlock2D",
             "CrossAttnDownBlock2D",
             "CrossAttnDownBlock2D",
             "DownBlock2D",
         ),
         mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
         up_block_types: Tuple[str, ...] = (
-            "UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"
+            "UpBlock2D",
+            "CrossAttnUpBlock2D",
+            "CrossAttnUpBlock2D",
+            "CrossAttnUpBlock2D",
         ),
         only_cross_attention: Union[bool, Tuple[bool]] = False,
         block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
         layers_per_block: int = 2,
         downsample_padding: int = 1,
         mid_block_scale_factor: float = 1,
         act_fn: str = "silu",
         norm_num_groups: Optional[int] = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: int = 1280,
         transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
         encoder_hid_dim: Optional[int] = None,
         encoder_hid_dim_type: Optional[str] = None,
         attention_head_dim: Union[int, Tuple[int, ...]] = 8,
         num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
         use_linear_projection: bool = False,
         class_embed_type: Optional[str] = None,
         addition_embed_type: Optional[str] = None,
         addition_time_embed_dim: Optional[int] = None,
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
         projection_class_embeddings_input_dim: Optional[int] = None,
         brushnet_conditioning_channel_order: str = "rgb",
-        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (
+            16,
+            32,
+            96,
+            256,
+        ),
         global_pool_conditions: bool = False,
         addition_embed_type_num_heads: int = 64,
     ):
         super().__init__()
 
@@ -195,25 +207,33 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                 f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
             )
 
-        if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+        if not isinstance(only_cross_attention, bool) and len(
+            only_cross_attention
+        ) != len(down_block_types):
             raise ValueError(
                 f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
             )
 
-        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
+            down_block_types
+        ):
             raise ValueError(
                 f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
             )
 
         if isinstance(transformer_layers_per_block, int):
-            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+            transformer_layers_per_block = [transformer_layers_per_block] * len(
+                down_block_types
+            )
 
         # input
         conv_in_kernel = 3
         conv_in_padding = (conv_in_kernel - 1) // 2
         self.conv_in_condition = nn.Conv2d(
-            in_channels + conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel,
-            padding=conv_in_padding
+            in_channels + conditioning_channels,
+            block_out_channels[0],
+            kernel_size=conv_in_kernel,
+            padding=conv_in_padding,
         )
 
         # time
@@ -229,7 +249,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         if encoder_hid_dim_type is None and encoder_hid_dim is not None:
             encoder_hid_dim_type = "text_proj"
             self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
-            logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+            logger.info(
+                "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
+            )
 
         if encoder_hid_dim is None and encoder_hid_dim_type is not None:
             raise ValueError(
@@ -274,7 +296,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
             # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
             # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
             # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
-            self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+            self.class_embedding = TimestepEmbedding(
+                projection_class_embeddings_input_dim, time_embed_dim
+            )
         else:
             self.class_embedding = None
 
@@ -285,21 +309,31 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                 text_time_embedding_from_dim = cross_attention_dim
 
             self.add_embedding = TextTimeEmbedding(
-                text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+                text_time_embedding_from_dim,
+                time_embed_dim,
+                num_heads=addition_embed_type_num_heads,
             )
         elif addition_embed_type == "text_image":
             # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
             # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
             # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
             self.add_embedding = TextImageTimeEmbedding(
-                text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+                text_embed_dim=cross_attention_dim,
+                image_embed_dim=cross_attention_dim,
+                time_embed_dim=time_embed_dim,
             )
         elif addition_embed_type == "text_time":
-            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
-            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+            self.add_time_proj = Timesteps(
+                addition_time_embed_dim, flip_sin_to_cos, freq_shift
+            )
+            self.add_embedding = TimestepEmbedding(
+                projection_class_embeddings_input_dim, time_embed_dim
+            )
 
         elif addition_embed_type is not None:
-            raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+            raise ValueError(
+                f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
+            )
 
         self.down_blocks = nn.ModuleList([])
         self.brushnet_down_blocks = nn.ModuleList([])
@@ -338,7 +372,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                 resnet_groups=norm_num_groups,
                 cross_attention_dim=cross_attention_dim,
                 num_attention_heads=num_attention_heads[i],
-                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+                attention_head_dim=attention_head_dim[i]
+                if attention_head_dim[i] is not None
+                else output_channel,
                 downsample_padding=downsample_padding,
                 use_linear_projection=use_linear_projection,
                 only_cross_attention=only_cross_attention[i],
@@ -348,12 +384,16 @@ class BrushNetModel(ModelMixin, ConfigMixin):
             self.down_blocks.append(down_block)
 
             for _ in range(layers_per_block):
-                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+                brushnet_block = nn.Conv2d(
+                    output_channel, output_channel, kernel_size=1
+                )
                 brushnet_block = zero_module(brushnet_block)
                 self.brushnet_down_blocks.append(brushnet_block)
 
             if not is_final_block:
-                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+                brushnet_block = nn.Conv2d(
+                    output_channel, output_channel, kernel_size=1
+                )
                 brushnet_block = zero_module(brushnet_block)
                 self.brushnet_down_blocks.append(brushnet_block)
 
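
Note: `zero_module` zero-initialises these 1x1 projection convs, the same "zero conv" idea as ControlNet: at the start of training BrushNet's residuals contribute nothing, and their influence grows as the weights move away from zero. The diff does not show `zero_module` itself, so the following is a sketch of the usual implementation, an assumption:

import torch.nn as nn

def zero_module(module: nn.Module) -> nn.Module:
    # Zero every parameter so the block contributes nothing at init.
    for p in module.parameters():
        nn.init.zeros_(p)
    return module

proj = zero_module(nn.Conv2d(320, 320, kernel_size=1))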
@@ -386,7 +426,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         reversed_num_attention_heads = list(reversed(num_attention_heads))
-        reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
+        reversed_transformer_layers_per_block = list(
+            reversed(transformer_layers_per_block)
+        )
         only_cross_attention = list(reversed(only_cross_attention))
 
         output_channel = reversed_block_out_channels[0]
@@ -399,7 +441,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
 
             prev_output_channel = output_channel
             output_channel = reversed_block_out_channels[i]
-            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+            input_channel = reversed_block_out_channels[
+                min(i + 1, len(block_out_channels) - 1)
+            ]
 
             # add upsample block for all BUT final layer
             if not is_final_block:
@@ -427,29 +471,40 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
-                attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+                attention_head_dim=attention_head_dim[i]
+                if attention_head_dim[i] is not None
+                else output_channel,
             )
             self.up_blocks.append(up_block)
             prev_output_channel = output_channel
 
             for _ in range(layers_per_block + 1):
-                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+                brushnet_block = nn.Conv2d(
+                    output_channel, output_channel, kernel_size=1
+                )
                 brushnet_block = zero_module(brushnet_block)
                 self.brushnet_up_blocks.append(brushnet_block)
 
             if not is_final_block:
-                brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+                brushnet_block = nn.Conv2d(
+                    output_channel, output_channel, kernel_size=1
+                )
                 brushnet_block = zero_module(brushnet_block)
                 self.brushnet_up_blocks.append(brushnet_block)
 
     @classmethod
     def from_unet(
         cls,
         unet: UNet2DConditionModel,
         brushnet_conditioning_channel_order: str = "rgb",
-        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+        conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (
+            16,
+            32,
+            96,
+            256,
+        ),
         load_weights_from_unet: bool = True,
         conditioning_channels: int = 5,
     ):
         r"""
         Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
@@ -460,13 +515,27 @@ class BrushNetModel(ModelMixin, ConfigMixin):
             where applicable.
         """
         transformer_layers_per_block = (
-            unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+            unet.config.transformer_layers_per_block
+            if "transformer_layers_per_block" in unet.config
+            else 1
+        )
+        encoder_hid_dim = (
+            unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
+        )
+        encoder_hid_dim_type = (
+            unet.config.encoder_hid_dim_type
+            if "encoder_hid_dim_type" in unet.config
+            else None
+        )
+        addition_embed_type = (
+            unet.config.addition_embed_type
+            if "addition_embed_type" in unet.config
+            else None
         )
-        encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
-        encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
-        addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
         addition_time_embed_dim = (
-            unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
+            unet.config.addition_time_embed_dim
+            if "addition_time_embed_dim" in unet.config
+            else None
         )
 
         brushnet = cls(
@@ -475,14 +544,21 @@ class BrushNetModel(ModelMixin, ConfigMixin):
             flip_sin_to_cos=unet.config.flip_sin_to_cos,
             freq_shift=unet.config.freq_shift,
             # down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
-            down_block_types=["CrossAttnDownBlock2D",
-                              "CrossAttnDownBlock2D",
-                              "CrossAttnDownBlock2D",
-                              "DownBlock2D", ],
+            down_block_types=[
+                "CrossAttnDownBlock2D",
+                "CrossAttnDownBlock2D",
+                "CrossAttnDownBlock2D",
+                "DownBlock2D",
+            ],
             # mid_block_type='MidBlock2D',
             mid_block_type="UNetMidBlock2DCrossAttn",
             # up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
-            up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
+            up_block_types=[
+                "UpBlock2D",
+                "CrossAttnUpBlock2D",
+                "CrossAttnUpBlock2D",
+                "CrossAttnUpBlock2D",
+            ],
             only_cross_attention=unet.config.only_cross_attention,
             block_out_channels=unet.config.block_out_channels,
             layers_per_block=unet.config.layers_per_block,
@@ -510,21 +586,33 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         )
 
         if load_weights_from_unet:
-            conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
+            conv_in_condition_weight = torch.zeros_like(
+                brushnet.conv_in_condition.weight
+            )
             conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
             conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
-            brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
+            brushnet.conv_in_condition.weight = torch.nn.Parameter(
+                conv_in_condition_weight
+            )
             brushnet.conv_in_condition.bias = unet.conv_in.bias
 
             brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
             brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
 
             if brushnet.class_embedding:
-                brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
+                brushnet.class_embedding.load_state_dict(
+                    unet.class_embedding.state_dict()
+                )
 
-            brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
-            brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
-            brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
+            brushnet.down_blocks.load_state_dict(
+                unet.down_blocks.state_dict(), strict=False
+            )
+            brushnet.mid_block.load_state_dict(
+                unet.mid_block.state_dict(), strict=False
+            )
+            brushnet.up_blocks.load_state_dict(
+                unet.up_blocks.state_dict(), strict=False
+            )
 
         return brushnet.to(unet.dtype)
 
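
Note: `conv_in_condition` takes `in_channels + conditioning_channels = 4 + 5 = 9` inputs, and the weight seeding above reuses the pretrained 4-channel `conv_in` kernel twice, once for the noisy latent and once for the masked-image latent, leaving the remaining channels zero-initialised. A standalone sketch of the same slicing, with shapes assumed to be the standard SD 1.5 values:

import torch

unet_conv_in_weight = torch.randn(320, 4, 3, 3)  # stand-in for unet.conv_in.weight
w = torch.zeros(320, 9, 3, 3)
w[:, :4] = unet_conv_in_weight   # noisy-latent channels
w[:, 4:8] = unet_conv_in_weight  # masked-image-latent channels
# w[:, 8] stays zero: the mask channel starts with no influence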
@@ -539,9 +627,15 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         # set recursively
         processors = {}
 
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+        def fn_recursive_add_processors(
+            name: str,
+            module: torch.nn.Module,
+            processors: Dict[str, AttentionProcessor],
+        ):
             if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+                processors[f"{name}.processor"] = module.get_processor(
+                    return_deprecated_lora=True
+                )
 
             for sub_name, child in module.named_children():
                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
@@ -554,7 +648,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         return processors
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
+    ):
         r"""
         Sets the attention processor to use to compute attention.
 
@@ -593,9 +689,15 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         """
         Disables custom attention processors and sets the default attention implementation.
         """
-        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+        if all(
+            proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
+            for proc in self.attn_processors.values()
+        ):
             processor = AttnAddedKVProcessor()
-        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+        elif all(
+            proc.__class__ in CROSS_ATTENTION_PROCESSORS
+            for proc in self.attn_processors.values()
+        ):
             processor = AttnProcessor()
         else:
             raise ValueError(
@@ -642,7 +744,11 @@ class BrushNetModel(ModelMixin, ConfigMixin):
             # make smallest slice possible
             slice_size = num_sliceable_layers * [1]
 
-        slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+        slice_size = (
+            num_sliceable_layers * [slice_size]
+            if not isinstance(slice_size, list)
+            else slice_size
+        )
 
         if len(slice_size) != len(sliceable_head_dims):
             raise ValueError(
@@ -659,7 +765,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         # Recursively walk through all the children.
         # Any children which exposes the set_attention_slice method
         # gets the message
-        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+        def fn_recursive_set_attention_slice(
+            module: torch.nn.Module, slice_size: List[int]
+        ):
             if hasattr(module, "set_attention_slice"):
                 module.set_attention_slice(slice_size.pop())
 
@@ -675,19 +783,19 @@ class BrushNetModel(ModelMixin, ConfigMixin):
             module.gradient_checkpointing = value
 
     def forward(
         self,
         sample: torch.FloatTensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
         brushnet_cond: torch.FloatTensor,
         conditioning_scale: float = 1.0,
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         guess_mode: bool = False,
         return_dict: bool = True,
     ) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
         """
         The [`BrushNetModel`] forward method.
@@ -737,7 +845,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         elif channel_order == "bgr":
             brushnet_cond = torch.flip(brushnet_cond, dims=[1])
         else:
-            raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
+            raise ValueError(
+                f"unknown `brushnet_conditioning_channel_order`: {channel_order}"
+            )
 
         # prepare attention_mask
         if attention_mask is not None:
@@ -773,7 +883,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
 
         if self.class_embedding is not None:
             if class_labels is None:
-                raise ValueError("class_labels should be provided when num_class_embeds > 0")
+                raise ValueError(
+                    "class_labels should be provided when num_class_embeds > 0"
+                )
 
             if self.config.class_embed_type == "timestep":
                 class_labels = self.time_proj(class_labels)
@@ -812,7 +924,10 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         # 3. down
         down_block_res_samples = (sample,)
         for downsample_block in self.down_blocks:
-            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+            if (
+                hasattr(downsample_block, "has_cross_attention")
+                and downsample_block.has_cross_attention
+            ):
                 sample, res_samples = downsample_block(
                     hidden_states=sample,
                     temb=emb,
@@ -827,13 +942,20 @@ class BrushNetModel(ModelMixin, ConfigMixin):
 
         # 4. PaintingNet down blocks
         brushnet_down_block_res_samples = ()
-        for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
+        for down_block_res_sample, brushnet_down_block in zip(
+            down_block_res_samples, self.brushnet_down_blocks
+        ):
             down_block_res_sample = brushnet_down_block(down_block_res_sample)
-            brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
+            brushnet_down_block_res_samples = brushnet_down_block_res_samples + (
+                down_block_res_sample,
+            )
 
         # 5. mid
         if self.mid_block is not None:
-            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+            if (
+                hasattr(self.mid_block, "has_cross_attention")
+                and self.mid_block.has_cross_attention
+            ):
                 sample = self.mid_block(
                     sample,
                     emb,
@@ -852,15 +974,20 @@ class BrushNetModel(ModelMixin, ConfigMixin):
         for i, upsample_block in enumerate(self.up_blocks):
             is_final_block = i == len(self.up_blocks) - 1
 
-            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
-            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[
+                : -len(upsample_block.resnets)
+            ]
 
             # if we have not reached the final block and need to forward the
             # upsample size, we do it here
             if not is_final_block:
                 upsample_size = down_block_res_samples[-1].shape[2:]
 
-            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+            if (
+                hasattr(upsample_block, "has_cross_attention")
+                and upsample_block.has_cross_attention
+            ):
                 sample, up_res_samples = upsample_block(
                     hidden_states=sample,
                     temb=emb,
@@ -869,7 +996,7 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                     cross_attention_kwargs=cross_attention_kwargs,
                     upsample_size=upsample_size,
                     attention_mask=attention_mask,
-                    return_res_samples=True
+                    return_res_samples=True,
                 )
             else:
                 sample, up_res_samples = upsample_block(
@@ -877,53 +1004,87 @@ class BrushNetModel(ModelMixin, ConfigMixin):
                     temb=emb,
                     res_hidden_states_tuple=res_samples,
                     upsample_size=upsample_size,
-                    return_res_samples=True
+                    return_res_samples=True,
                 )
 
             up_block_res_samples += up_res_samples
 
         # 8. BrushNet up blocks
         brushnet_up_block_res_samples = ()
-        for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
+        for up_block_res_sample, brushnet_up_block in zip(
+            up_block_res_samples, self.brushnet_up_blocks
+        ):
             up_block_res_sample = brushnet_up_block(up_block_res_sample)
-            brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
+            brushnet_up_block_res_samples = brushnet_up_block_res_samples + (
+                up_block_res_sample,
+            )
 
         # 6. scaling
         if guess_mode and not self.config.global_pool_conditions:
-            scales = torch.logspace(-1, 0,
-                                    len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
-                                    device=sample.device)  # 0.1 to 1.0
+            scales = torch.logspace(
+                -1,
+                0,
+                len(brushnet_down_block_res_samples)
+                + 1
+                + len(brushnet_up_block_res_samples),
+                device=sample.device,
+            )  # 0.1 to 1.0
             scales = scales * conditioning_scale
 
-            brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples,
-                                                                                       scales[:len(brushnet_down_block_res_samples)])]
-            brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
-            brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples,
-                                                                                     scales[len(brushnet_down_block_res_samples) + 1:])]
+            brushnet_down_block_res_samples = [
+                sample * scale
+                for sample, scale in zip(
+                    brushnet_down_block_res_samples,
+                    scales[: len(brushnet_down_block_res_samples)],
+                )
+            ]
+            brushnet_mid_block_res_sample = (
+                brushnet_mid_block_res_sample
+                * scales[len(brushnet_down_block_res_samples)]
+            )
+            brushnet_up_block_res_samples = [
+                sample * scale
+                for sample, scale in zip(
+                    brushnet_up_block_res_samples,
+                    scales[len(brushnet_down_block_res_samples) + 1 :],
+                )
+            ]
         else:
-            brushnet_down_block_res_samples = [sample * conditioning_scale for sample in
-                                               brushnet_down_block_res_samples]
-            brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
-            brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
+            brushnet_down_block_res_samples = [
+                sample * conditioning_scale
+                for sample in brushnet_down_block_res_samples
+            ]
+            brushnet_mid_block_res_sample = (
+                brushnet_mid_block_res_sample * conditioning_scale
+            )
+            brushnet_up_block_res_samples = [
+                sample * conditioning_scale for sample in brushnet_up_block_res_samples
+            ]
 
         if self.config.global_pool_conditions:
             brushnet_down_block_res_samples = [
-                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
+                torch.mean(sample, dim=(2, 3), keepdim=True)
+                for sample in brushnet_down_block_res_samples
             ]
-            brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
+            brushnet_mid_block_res_sample = torch.mean(
+                brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True
+            )
             brushnet_up_block_res_samples = [
-                torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
+                torch.mean(sample, dim=(2, 3), keepdim=True)
+                for sample in brushnet_up_block_res_samples
             ]
 
         if not return_dict:
-            return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
+            return (
+                brushnet_down_block_res_samples,
+                brushnet_mid_block_res_sample,
+                brushnet_up_block_res_samples,
+            )
 
         return BrushNetOutput(
             down_block_res_samples=brushnet_down_block_res_samples,
             mid_block_res_sample=brushnet_mid_block_res_sample,
-            up_block_res_samples=brushnet_up_block_res_samples
+            up_block_res_samples=brushnet_up_block_res_samples,
         )
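
Note on the guess-mode weighting above: `torch.logspace(-1, 0, n)` produces n multipliers from 0.1 to 1.0, so residuals from later blocks are blended in more strongly than early ones. A small sketch with assumed residual counts:

import torch

n_down, n_up = 12, 13  # example counts, not taken from the diff
scales = torch.logspace(-1, 0, n_down + 1 + n_up)  # 0.1 ... 1.0
down_scales = scales[:n_down]
mid_scale = scales[n_down]
up_scales = scales[n_down + 1:]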
(three file diffs suppressed because they are too large)
@@ -122,9 +122,13 @@ class ModelInfo(BaseModel):
     @computed_field
     @property
     def support_powerpaint_v2(self) -> bool:
-        return self.model_type in [
-            ModelType.DIFFUSERS_SD,
-        ]
+        return (
+            self.model_type
+            in [
+                ModelType.DIFFUSERS_SD,
+            ]
+            and self.name != POWERPAINT_NAME
+        )
 
 
 class Choices(str, Enum):
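
Note: `support_powerpaint_v2` uses the pydantic v2 `@computed_field` pattern, which serialises a derived, read-only property alongside regular fields; the added `self.name != POWERPAINT_NAME` clause keeps the dedicated PowerPaint v1 model from also advertising v2 support. A minimal sketch of the pattern, assuming pydantic >= 2:

from pydantic import BaseModel, computed_field

class Demo(BaseModel):
    name: str

    @computed_field
    @property
    def is_powerpaint(self) -> bool:
        return "PowerPaint" in self.name

print(Demo(name="PowerPaint-V2").model_dump())  # includes "is_powerpaint": True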
@@ -215,7 +219,6 @@ class SDSampler(str, Enum):
     lcm = "LCM"
 
 
-
 class PowerPaintTask(Choices):
     text_guided = "text-guided"
     context_aware = "context-aware"
@@ -59,6 +59,9 @@ const DiffusionOptions = () => {
     updateExtenderDirection,
     adjustMask,
     clearMask,
+    updateEnablePowerPaintV2,
+    updateEnableBrushNet,
+    updateEnableControlnet,
   ] = useStore((state) => [
     state.serverConfig.samplers,
     state.settings,
@@ -71,6 +74,9 @@ const DiffusionOptions = () => {
     state.updateExtenderDirection,
     state.adjustMask,
     state.clearMask,
+    state.updateEnablePowerPaintV2,
+    state.updateEnableBrushNet,
+    state.updateEnableControlnet,
   ])
   const [exampleImage, isExampleImageLoaded] = useImage(paintByExampleFile)
   const negativePromptRef = useRef(null)
@@ -114,12 +120,8 @@ const DiffusionOptions = () => {
       return null
     }
 
-    let disable = settings.enableControlnet
     let toolTip =
-      "BrushNet is a plug-and-play image inpainting model with decomposed dual-branch diffusion. It can be used to inpaint images by conditioning on a mask."
-    if (disable) {
-      toolTip = "ControlNet is enabled, BrushNet is disabled."
-    }
+      "BrushNet is a plug-and-play image inpainting model works on any SD1.5 base models."
 
     return (
       <div className="flex flex-col gap-4">
@@ -129,20 +131,19 @@ const DiffusionOptions = () => {
           text="BrushNet"
           url="https://github.com/TencentARC/BrushNet"
           toolTip={toolTip}
-          disabled={disable}
         />
         <Switch
           id="brushnet"
           checked={settings.enableBrushNet}
           onCheckedChange={(value) => {
-            updateSettings({ enableBrushNet: value })
+            updateEnableBrushNet(value)
           }}
-          disabled={disable}
         />
       </RowContainer>
-      <RowContainer>
+      {/* <RowContainer>
         <Slider
           defaultValue={[100]}
+          className="w-[180px]"
           min={1}
           max={100}
           step={1}
@@ -155,14 +156,13 @@ const DiffusionOptions = () => {
         <NumberInput
           id="brushnet-weight"
           className="w-[60px] rounded-full"
-          disabled={!settings.enableBrushNet || disable}
           numberValue={settings.brushnetConditioningScale}
           allowFloat={false}
           onNumberValueChange={(val) => {
             updateSettings({ brushnetConditioningScale: val })
           }}
         />
-      </RowContainer>
+      </RowContainer> */}
 
       <RowContainer>
         <Select
@@ -198,12 +198,8 @@ const DiffusionOptions = () => {
       return null
     }
 
-    let disable = settings.enableBrushNet
     let toolTip =
       "Using an additional conditioning image to control how an image is generated"
-    if (disable) {
-      toolTip = "BrushNet is enabled, ControlNet is disabled."
-    }
 
     return (
       <div className="flex flex-col gap-4">
@@ -213,15 +209,13 @@ const DiffusionOptions = () => {
           text="ControlNet"
           url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inpaint#controlnet"
           toolTip={toolTip}
-          disabled={disable}
         />
         <Switch
           id="controlnet"
           checked={settings.enableControlnet}
           onCheckedChange={(value) => {
-            updateSettings({ enableControlnet: value })
+            updateEnableControlnet(value)
          }}
-          disabled={disable}
         />
       </RowContainer>
 
@@ -233,7 +227,7 @@ const DiffusionOptions = () => {
           min={1}
           max={100}
           step={1}
-          disabled={!settings.enableControlnet || disable}
+          disabled={!settings.enableControlnet}
           value={[Math.floor(settings.controlnetConditioningScale * 100)]}
           onValueChange={(vals) =>
             updateSettings({ controlnetConditioningScale: vals[0] / 100 })
@@ -242,7 +236,7 @@ const DiffusionOptions = () => {
         <NumberInput
           id="controlnet-weight"
           className="w-[60px] rounded-full"
-          disabled={!settings.enableControlnet || disable}
+          disabled={!settings.enableControlnet}
           numberValue={settings.controlnetConditioningScale}
           allowFloat={false}
           onNumberValueChange={(val) => {
@@ -286,12 +280,8 @@ const DiffusionOptions = () => {
       return null
     }
 
-    let disable = settings.enableBrushNet
     let toolTip =
-      "Enable quality image generation in typically 2-4 steps. Suggest disabling guidance_scale by setting it to 0. You can also try values between 1.0 and 2.0. When LCM Lora is enabled, LCMSampler will be used automatically."
-    if (disable) {
-      toolTip = "BrushNet is enabled, LCM Lora is disabled."
-    }
+      "Enable quality image generation in typically 2-8 steps. Suggest disabling guidance_scale by setting it to 0. You can also try values between 1.0 and 2.0. When LCM Lora is enabled, LCMSampler will be used automatically."
 
     return (
       <>
@@ -300,7 +290,6 @@ const DiffusionOptions = () => {
           text="LCM LoRA"
           url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inference_with_lcm_lora"
           toolTip={toolTip}
-          disabled={disable}
         />
         <Switch
           id="lcm-lora"
@@ -308,7 +297,6 @@ const DiffusionOptions = () => {
           onCheckedChange={(value) => {
            updateSettings({ enableLCMLora: value })
           }}
-          disabled={disable}
         />
       </RowContainer>
       <Separator />
@@ -561,10 +549,6 @@ const DiffusionOptions = () => {
   }
 
   const renderPowerPaintTaskType = () => {
-    if (settings.model.name !== POWERPAINT) {
-      return null
-    }
-
     return (
       <RowContainer>
         <LabelTitle
@@ -579,7 +563,7 @@ const DiffusionOptions = () => {
           }}
           disabled={settings.showExtender}
         >
-          <SelectTrigger className="w-[140px]">
+          <SelectTrigger className="w-[130px]">
             <SelectValue placeholder="Select task" />
           </SelectTrigger>
           <SelectContent align="end">
@@ -587,6 +571,7 @@ const DiffusionOptions = () => {
             {[
               PowerPaintTask.text_guided,
               PowerPaintTask.object_remove,
+              PowerPaintTask.context_aware,
               PowerPaintTask.shape_guided,
             ].map((task) => (
               <SelectItem key={task} value={task}>
@@ -600,6 +585,44 @@ const DiffusionOptions = () => {
     )
   }

+  const renderPowerPaintV1 = () => {
+    if (settings.model.name !== POWERPAINT) {
+      return null
+    }
+    return (
+      <>
+        {renderPowerPaintTaskType()}
+        <Separator />
+      </>
+    )
+  }
+
+  const renderPowerPaintV2 = () => {
+    if (settings.model.support_powerpaint_v2 === false) {
+      return null
+    }
+
+    return (
+      <>
+        <RowContainer>
+          <LabelTitle
+            text="PowerPaint V2"
+            toolTip="PowerPaint is a plug-and-play image inpainting model that works on any SD1.5 base model."
+          />
+          <Switch
+            id="powerpaint-v2"
+            checked={settings.enablePowerPaintV2}
+            onCheckedChange={(value) => {
+              updateEnablePowerPaintV2(value)
+            }}
+          />
+        </RowContainer>
+        {renderPowerPaintTaskType()}
+        <Separator />
+      </>
+    )
+  }
+
   const renderSteps = () => {
     return (
       <RowContainer>
@@ -868,7 +891,7 @@ const DiffusionOptions = () => {
       {renderMaskBlur()}
       {renderMaskAdjuster()}
       {renderMatchHistograms()}
-      {renderPowerPaintTaskType()}
+      {renderPowerPaintV1()}
       {renderSteps()}
       {renderGuidanceScale()}
       {renderP2PImageGuidanceScale()}
@@ -878,6 +901,7 @@ const DiffusionOptions = () => {
       {renderNegativePrompt()}
       <Separator />
       {renderBrushNetSetting()}
+      {renderPowerPaintV2()}
       {renderConterNetSetting()}
       {renderLCMLora()}
       {renderPaintByExample()}
@@ -22,7 +22,7 @@ const SelectTrigger = React.forwardRef<
   <SelectPrimitive.Trigger
     ref={ref}
     className={cn(
-      "flex h-9 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent px-3 py-2 text-sm shadow-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-1 focus:ring-ring disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1",
+      "flex h-9 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent pl-2 pr-1 py-2 text-sm shadow-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-1 focus:ring-ring disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1",
       className
     )}
     tabIndex={-1}
@@ -79,6 +79,7 @@ export default async function inpaint(
     enable_brushnet: settings.enableBrushNet,
     brushnet_method: settings.brushnetMethod ? settings.brushnetMethod : "",
     brushnet_conditioning_scale: settings.brushnetConditioningScale,
+    enable_powerpaint_v2: settings.enablePowerPaintV2,
    powerpaint_task: settings.showExtender
       ? PowerPaintTask.outpainting
       : settings.powerpaintTask,
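With the added field, the body posted by `inpaint` carries both the BrushNet flags and the new PowerPaint V2 flag. An illustrative fragment only; the field names come from the hunk above, while the values here are made up:

```ts
// Illustrative slice of the inpaint request body (example values only).
const payloadFragment = {
  enable_brushnet: false,
  brushnet_method: "",
  brushnet_conditioning_scale: 1.0,
  enable_powerpaint_v2: true,
  powerpaint_task: "object-remove", // forced to "outpainting" when the extender is shown
}
console.log(JSON.stringify(payloadFragment, null, 2))
```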
@@ -106,6 +106,7 @@ export type Settings = {
   enableLCMLora: boolean

   // PowerPaint
+  enablePowerPaintV2: boolean
   powerpaintTask: PowerPaintTask

   // AdjustMask
@@ -194,6 +195,12 @@ type AppAction = {
   setServerConfig: (newValue: ServerConfig) => void
   setSeed: (newValue: number) => void
   updateSettings: (newSettings: Partial<Settings>) => void
+
+  // Mutually exclusive
+  updateEnablePowerPaintV2: (newValue: boolean) => void
+  updateEnableBrushNet: (newValue: boolean) => void
+  updateEnableControlnet: (newValue: boolean) => void
+
   setModel: (newModel: ModelInfo) => void
   updateFileManagerState: (newState: Partial<FileManagerState>) => void
   updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
@@ -311,6 +318,7 @@ const defaultValues: AppState = {
     support_brushnet: false,
     support_strength: false,
     support_outpainting: false,
+    support_powerpaint_v2: false,
     controlnets: [],
     brushnets: [],
     support_lcm_lora: false,
@@ -425,6 +433,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
       if (
         get().settings.model.support_outpainting &&
         settings.showExtender &&
+        extenderState.x === 0 &&
+        extenderState.y === 0 &&
         extenderState.height === imageHeight &&
         extenderState.width === imageWidth
       ) {
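The two added conditions tighten the "extender is a no-op" check: matching width and height alone no longer suffice, because an equal-sized rectangle can still be offset and therefore crop the image. A small sketch of the predicate, with a hypothetical type standing in for the store's extender state:

```ts
// Hypothetical shape for the extender rectangle used in the check above.
interface ExtenderState {
  x: number
  y: number
  width: number
  height: number
}

// The extender changes nothing only when it sits at the origin AND matches
// the image size; an equal-sized but shifted rectangle still crops.
const extenderIsNoop = (
  ext: ExtenderState,
  imageWidth: number,
  imageHeight: number
): boolean =>
  ext.x === 0 &&
  ext.y === 0 &&
  ext.width === imageWidth &&
  ext.height === imageHeight
```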
@@ -798,6 +808,38 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
         })
       },
+
+      updateEnablePowerPaintV2: (newValue: boolean) => {
+        get().updateSettings({ enablePowerPaintV2: newValue })
+        if (newValue) {
+          get().updateSettings({
+            enableBrushNet: false,
+            enableControlnet: false,
+            enableLCMLora: false,
+          })
+        }
+      },
+
+      updateEnableBrushNet: (newValue: boolean) => {
+        get().updateSettings({ enableBrushNet: newValue })
+        if (newValue) {
+          get().updateSettings({
+            enablePowerPaintV2: false,
+            enableControlnet: false,
+            enableLCMLora: false,
+          })
+        }
+      },
+
+      updateEnableControlnet(newValue) {
+        get().updateSettings({ enableControlnet: newValue })
+        if (newValue) {
+          get().updateSettings({
+            enablePowerPaintV2: false,
+            enableBrushNet: false,
+          })
+        }
+      },

       setModel: (newModel: ModelInfo) => {
         set((state) => {
           state.settings.model = newModel
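The three actions above repeat one pattern: turning a feature on switches its competitors off, while turning it off leaves them alone. A hedged refactoring sketch, not part of the commit, that factors the pattern out:

```ts
// Sketch only: a generic mutual-exclusion setter. The commit wires the three
// updateEnable* actions by hand, as shown in the hunk above.
type ExclusiveFlag =
  | "enablePowerPaintV2"
  | "enableBrushNet"
  | "enableControlnet"
  | "enableLCMLora"

const setExclusive = (
  update: (patch: Partial<Record<ExclusiveFlag, boolean>>) => void,
  flag: ExclusiveFlag,
  value: boolean,
  rivals: ExclusiveFlag[]
): void => {
  update({ [flag]: value })
  if (value) {
    // Enabling one feature disables its rivals; disabling does nothing extra.
    update(Object.fromEntries(rivals.map((r) => [r, false] as const)))
  }
}

// Usage mirroring updateEnableBrushNet from the hunk above:
// setExclusive(update, "enableBrushNet", true,
//   ["enablePowerPaintV2", "enableControlnet", "enableLCMLora"])
```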
@@ -49,6 +49,7 @@ export interface ModelInfo {
   support_outpainting: boolean
   support_controlnet: boolean
   support_brushnet: boolean
+  support_powerpaint_v2: boolean
   controlnets: string[]
   brushnets: string[]
   support_lcm_lora: boolean
@@ -123,6 +124,7 @@ export enum ExtenderDirection {
 export enum PowerPaintTask {
   text_guided = "text-guided",
   shape_guided = "shape-guided",
+  context_aware = "context-aware",
   object_remove = "object-remove",
   outpainting = "outpainting",
 }
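With `context_aware` added, the enum covers the five task strings the backend dispatches on. A self-contained sketch of how a caller picks one, consistent with the `powerpaint_task` rule in the inpaint hunk earlier; the `pickTask` helper is hypothetical:

```ts
// Mirror of the enum in the hunk above, repeated so the example is runnable.
enum PowerPaintTask {
  text_guided = "text-guided",
  shape_guided = "shape-guided",
  context_aware = "context-aware",
  object_remove = "object-remove",
  outpainting = "outpainting",
}

// Matches the request-building rule shown earlier: the extender forces
// outpainting, otherwise the user-selected task wins.
const pickTask = (
  showExtender: boolean,
  selected: PowerPaintTask
): PowerPaintTask => (showExtender ? PowerPaintTask.outpainting : selected)

console.log(pickTask(false, PowerPaintTask.context_aware)) // "context-aware"
```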