update
parent 017a3d68fd
commit 80ee1b9941
@@ -1,6 +1,9 @@
from itertools import chain

import PIL.Image
import cv2
import torch
from iopaint.model.original_sd_configs import get_config_files
from loguru import logger
from transformers import CLIPTextModel, CLIPTokenizer
import numpy as np
@@ -14,9 +17,15 @@ from ..utils import (
handle_from_pretrained_exceptions,
)
from .powerpaint_tokenizer import task_to_prompt
from iopaint.schema import InpaintRequest
from iopaint.schema import InpaintRequest, ModelType
from .v2.BrushNet_CA import BrushNetModel
from .v2.unet_2d_condition import UNet2DConditionModel
from .v2.unet_2d_condition import UNet2DConditionModel_forward
from .v2.unet_2d_blocks import (
CrossAttnDownBlock2D_forward,
DownBlock2D_forward,
CrossAttnUpBlock2D_forward,
UpBlock2D_forward,
)


class PowerPaintV2(DiffusionInpaintModel):
@@ -50,14 +59,7 @@ class PowerPaintV2(DiffusionInpaintModel):
torch_dtype=torch_dtype,
local_files_only=model_kwargs["local_files_only"],
)
unet = handle_from_pretrained_exceptions(
UNet2DConditionModel.from_pretrained,
pretrained_model_name_or_path=self.model_id_or_path,
subfolder="unet",
variant="fp16",
torch_dtype=torch_dtype,
local_files_only=model_kwargs["local_files_only"],
)

brushnet = BrushNetModel.from_pretrained(
self.hf_model_id,
subfolder="PowerPaint_Brushnet",
@@ -65,11 +67,27 @@ class PowerPaintV2(DiffusionInpaintModel):
torch_dtype=torch_dtype,
local_files_only=model_kwargs["local_files_only"],
)

if self.model_info.is_single_file_diffusers:
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
model_kwargs["num_in_channels"] = 4
else:
model_kwargs["num_in_channels"] = 9

pipe = StableDiffusionPowerPaintBrushNetPipeline.from_single_file(
self.model_id_or_path,
torch_dtype=torch_dtype,
load_safety_checker=False,
original_config_file=get_config_files()["v1"],
brushnet=brushnet,
text_encoder_brushnet=text_encoder_brushnet,
**model_kwargs,
)
else:
pipe = handle_from_pretrained_exceptions(
StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
pretrained_model_name_or_path=self.model_id_or_path,
torch_dtype=torch_dtype,
unet=unet,
brushnet=brushnet,
text_encoder_brushnet=text_encoder_brushnet,
variant="fp16",
@@ -95,6 +113,34 @@ class PowerPaintV2(DiffusionInpaintModel):

self.callback = kwargs.pop("callback", None)

# Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
self.model.unet.forward = UNet2DConditionModel_forward.__get__(
self.model.unet, self.model.unet.__class__
)

# Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
for down_block in chain(
self.model.unet.down_blocks, self.model.brushnet.down_blocks
):
if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
down_block.forward = CrossAttnDownBlock2D_forward.__get__(
down_block, down_block.__class__
)
else:
down_block.forward = DownBlock2D_forward.__get__(
down_block, down_block.__class__
)

for up_block in chain(self.model.unet.up_blocks, self.model.brushnet.up_blocks):
if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
up_block.forward = CrossAttnUpBlock2D_forward.__get__(
up_block, up_block.__class__
)
else:
up_block.forward = UpBlock2D_forward.__get__(
up_block, up_block.__class__
)
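A note on the mechanism used above: plain Python functions are descriptors, so UNet2DConditionModel_forward.__get__(instance, cls) returns a method bound to that one instance, and only the patched objects pick up the replacement forward. A minimal, self-contained sketch of the binding trick (the class and function names below are illustrative, not taken from the codebase):

class Block:
    def forward(self):
        return "original forward"

def patched_forward(self):
    # `self` is the instance the plain function was bound to via __get__
    return "patched forward"

block = Block()
block.forward = patched_forward.__get__(block, Block)  # bind to this instance only
assert block.forward() == "patched forward"
assert Block().forward() == "original forward"  # other instances are unaffected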
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
@@ -129,11 +175,10 @@ class PowerPaintV2(DiffusionInpaintModel):
brushnet_conditioning_scale=1.0,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
callback_on_step_end=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),
callback_steps=1,
).images[0]

output = (output * 255).round().astype("uint8")
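For context on the callback change in the hunk above: newer diffusers pipelines deprecate the callback/callback_steps pair in favor of callback_on_step_end, which in the standard API receives the pipeline, the step index, the current timestep, and a dict of tensors, and must return that dict. A hedged sketch of that shape (whether this custom PowerPaint pipeline forwards all four arguments is not shown in the diff):

def on_step_end(pipe, step: int, timestep, callback_kwargs: dict) -> dict:
    # Called once per denoising step; tensors in callback_kwargs can be
    # inspected or modified before the pipeline reads them back.
    print(f"finished step {step} at timestep {int(timestep)}")
    return callback_kwargs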
@@ -2,6 +2,14 @@ from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from diffusers import UNet2DConditionModel
from diffusers.models.unet_2d_blocks import (
get_down_block,
get_mid_block,
get_up_block,
CrossAttnDownBlock2D,
DownBlock2D,
)
from torch import nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
@@ -13,18 +21,14 @@ from diffusers.models.attention_processor import (
AttnAddedKVProcessor,
AttnProcessor,
)
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, \
TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from .unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
get_down_block,
get_mid_block,
get_up_block
from diffusers.models.embeddings import (
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
TimestepEmbedding,
Timesteps,
)

from .unet_2d_condition import UNet2DConditionModel
from diffusers.models.modeling_utils import ModelMixin

logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@ -145,7 +149,10 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
up_block_types: Tuple[str, ...] = (
|
||||
"UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
@ -170,7 +177,12 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
brushnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (
|
||||
16,
|
||||
32,
|
||||
96,
|
||||
256,
|
||||
),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
):
|
||||
@ -195,25 +207,33 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
||||
if not isinstance(only_cross_attention, bool) and len(
|
||||
only_cross_attention
|
||||
) != len(down_block_types):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
||||
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
|
||||
down_block_types
|
||||
):
|
||||
raise ValueError(
|
||||
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||
)
|
||||
|
||||
if isinstance(transformer_layers_per_block, int):
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||
transformer_layers_per_block = [transformer_layers_per_block] * len(
|
||||
down_block_types
|
||||
)
|
||||
|
||||
# input
|
||||
conv_in_kernel = 3
|
||||
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||
self.conv_in_condition = nn.Conv2d(
|
||||
in_channels + conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel,
|
||||
padding=conv_in_padding
|
||||
in_channels + conditioning_channels,
|
||||
block_out_channels[0],
|
||||
kernel_size=conv_in_kernel,
|
||||
padding=conv_in_padding,
|
||||
)
|
||||
|
||||
# time
|
||||
@ -229,7 +249,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
||||
encoder_hid_dim_type = "text_proj"
|
||||
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
||||
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
||||
logger.info(
|
||||
"encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
|
||||
)
|
||||
|
||||
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
||||
raise ValueError(
|
||||
@ -274,7 +296,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
||||
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
||||
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
||||
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||
self.class_embedding = TimestepEmbedding(
|
||||
projection_class_embeddings_input_dim, time_embed_dim
|
||||
)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
@ -285,21 +309,31 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
text_time_embedding_from_dim = cross_attention_dim
|
||||
|
||||
self.add_embedding = TextTimeEmbedding(
|
||||
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
||||
text_time_embedding_from_dim,
|
||||
time_embed_dim,
|
||||
num_heads=addition_embed_type_num_heads,
|
||||
)
|
||||
elif addition_embed_type == "text_image":
|
||||
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
||||
self.add_embedding = TextImageTimeEmbedding(
|
||||
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
||||
text_embed_dim=cross_attention_dim,
|
||||
image_embed_dim=cross_attention_dim,
|
||||
time_embed_dim=time_embed_dim,
|
||||
)
|
||||
elif addition_embed_type == "text_time":
|
||||
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
||||
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||
self.add_time_proj = Timesteps(
|
||||
addition_time_embed_dim, flip_sin_to_cos, freq_shift
|
||||
)
|
||||
self.add_embedding = TimestepEmbedding(
|
||||
projection_class_embeddings_input_dim, time_embed_dim
|
||||
)
|
||||
|
||||
elif addition_embed_type is not None:
|
||||
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
||||
raise ValueError(
|
||||
f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'."
|
||||
)
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.brushnet_down_blocks = nn.ModuleList([])
|
||||
@ -338,7 +372,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
resnet_groups=norm_num_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
num_attention_heads=num_attention_heads[i],
|
||||
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||
attention_head_dim=attention_head_dim[i]
|
||||
if attention_head_dim[i] is not None
|
||||
else output_channel,
|
||||
downsample_padding=downsample_padding,
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
@ -348,12 +384,16 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
for _ in range(layers_per_block):
|
||||
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
brushnet_block = nn.Conv2d(
|
||||
output_channel, output_channel, kernel_size=1
|
||||
)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_down_blocks.append(brushnet_block)
|
||||
|
||||
if not is_final_block:
|
||||
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
brushnet_block = nn.Conv2d(
|
||||
output_channel, output_channel, kernel_size=1
|
||||
)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_down_blocks.append(brushnet_block)
|
||||
|
||||
@ -386,7 +426,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
# up
|
||||
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
||||
reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
|
||||
reversed_transformer_layers_per_block = list(
|
||||
reversed(transformer_layers_per_block)
|
||||
)
|
||||
only_cross_attention = list(reversed(only_cross_attention))
|
||||
|
||||
output_channel = reversed_block_out_channels[0]
|
||||
@ -399,7 +441,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
|
||||
prev_output_channel = output_channel
|
||||
output_channel = reversed_block_out_channels[i]
|
||||
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||
input_channel = reversed_block_out_channels[
|
||||
min(i + 1, len(block_out_channels) - 1)
|
||||
]
|
||||
|
||||
# add upsample block for all BUT final layer
|
||||
if not is_final_block:
|
||||
@ -427,18 +471,24 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||
attention_head_dim=attention_head_dim[i]
|
||||
if attention_head_dim[i] is not None
|
||||
else output_channel,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
|
||||
for _ in range(layers_per_block + 1):
|
||||
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
brushnet_block = nn.Conv2d(
|
||||
output_channel, output_channel, kernel_size=1
|
||||
)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_up_blocks.append(brushnet_block)
|
||||
|
||||
if not is_final_block:
|
||||
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||
brushnet_block = nn.Conv2d(
|
||||
output_channel, output_channel, kernel_size=1
|
||||
)
|
||||
brushnet_block = zero_module(brushnet_block)
|
||||
self.brushnet_up_blocks.append(brushnet_block)
|
||||
|
||||
@ -447,7 +497,12 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
cls,
|
||||
unet: UNet2DConditionModel,
|
||||
brushnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (
|
||||
16,
|
||||
32,
|
||||
96,
|
||||
256,
|
||||
),
|
||||
load_weights_from_unet: bool = True,
|
||||
conditioning_channels: int = 5,
|
||||
):
|
||||
@ -460,13 +515,27 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
where applicable.
|
||||
"""
|
||||
transformer_layers_per_block = (
|
||||
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
||||
unet.config.transformer_layers_per_block
|
||||
if "transformer_layers_per_block" in unet.config
|
||||
else 1
|
||||
)
|
||||
encoder_hid_dim = (
|
||||
unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
||||
)
|
||||
encoder_hid_dim_type = (
|
||||
unet.config.encoder_hid_dim_type
|
||||
if "encoder_hid_dim_type" in unet.config
|
||||
else None
|
||||
)
|
||||
addition_embed_type = (
|
||||
unet.config.addition_embed_type
|
||||
if "addition_embed_type" in unet.config
|
||||
else None
|
||||
)
|
||||
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
||||
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
||||
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
||||
addition_time_embed_dim = (
|
||||
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
||||
unet.config.addition_time_embed_dim
|
||||
if "addition_time_embed_dim" in unet.config
|
||||
else None
|
||||
)
|
||||
|
||||
brushnet = cls(
|
||||
@ -475,14 +544,21 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
||||
freq_shift=unet.config.freq_shift,
|
||||
# down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
|
||||
down_block_types=["CrossAttnDownBlock2D",
|
||||
down_block_types=[
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D", ],
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
],
|
||||
# mid_block_type='MidBlock2D',
|
||||
mid_block_type="UNetMidBlock2DCrossAttn",
|
||||
# up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
|
||||
up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
|
||||
up_block_types=[
|
||||
"UpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
"CrossAttnUpBlock2D",
|
||||
],
|
||||
only_cross_attention=unet.config.only_cross_attention,
|
||||
block_out_channels=unet.config.block_out_channels,
|
||||
layers_per_block=unet.config.layers_per_block,
|
||||
@@ -510,21 +586,33 @@ class BrushNetModel(ModelMixin, ConfigMixin):
)

if load_weights_from_unet:
conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
conv_in_condition_weight = torch.zeros_like(
brushnet.conv_in_condition.weight
)
conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
brushnet.conv_in_condition.weight = torch.nn.Parameter(
conv_in_condition_weight
)
brushnet.conv_in_condition.bias = unet.conv_in.bias

brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())

if brushnet.class_embedding:
brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
brushnet.class_embedding.load_state_dict(
unet.class_embedding.state_dict()
)

brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
brushnet.down_blocks.load_state_dict(
unet.down_blocks.state_dict(), strict=False
)
brushnet.mid_block.load_state_dict(
unet.mid_block.state_dict(), strict=False
)
brushnet.up_blocks.load_state_dict(
unet.up_blocks.state_dict(), strict=False
)

return brushnet.to(unet.dtype)
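The weight copy above follows the usual recipe for widening a pretrained input convolution so it accepts extra conditioning channels: start from zeros, reuse the pretrained 4-channel kernel for the latent slice (and a second time here for the masked-image latents), and leave the remaining mask channel at zero so the branch initially behaves like the UNet. A standalone sketch with illustrative channel counts (the 320 output channels and the 4 + 5 input layout are assumptions for the example):

import torch
import torch.nn as nn

pretrained = nn.Conv2d(4, 320, kernel_size=3, padding=1)   # stand-in for unet.conv_in
widened = nn.Conv2d(4 + 5, 320, kernel_size=3, padding=1)  # latents + conditioning channels

with torch.no_grad():
    weight = torch.zeros_like(widened.weight)
    weight[:, :4] = pretrained.weight    # latent channels reuse the pretrained kernel
    weight[:, 4:8] = pretrained.weight   # masked-image latent channels reuse it as well
    widened.weight.copy_(weight)         # the remaining mask channel stays zero
    widened.bias.copy_(pretrained.bias)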
@ -539,9 +627,15 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
def fn_recursive_add_processors(
|
||||
name: str,
|
||||
module: torch.nn.Module,
|
||||
processors: Dict[str, AttentionProcessor],
|
||||
):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
processors[f"{name}.processor"] = module.get_processor(
|
||||
return_deprecated_lora=True
|
||||
)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
@ -554,7 +648,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
@ -593,9 +689,15 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
if all(
|
||||
proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
|
||||
for proc in self.attn_processors.values()
|
||||
):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
elif all(
|
||||
proc.__class__ in CROSS_ATTENTION_PROCESSORS
|
||||
for proc in self.attn_processors.values()
|
||||
):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -642,7 +744,11 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
# make smallest slice possible
|
||||
slice_size = num_sliceable_layers * [1]
|
||||
|
||||
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
||||
slice_size = (
|
||||
num_sliceable_layers * [slice_size]
|
||||
if not isinstance(slice_size, list)
|
||||
else slice_size
|
||||
)
|
||||
|
||||
if len(slice_size) != len(sliceable_head_dims):
|
||||
raise ValueError(
|
||||
@ -659,7 +765,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_attention_slice method
|
||||
# gets the message
|
||||
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||
def fn_recursive_set_attention_slice(
|
||||
module: torch.nn.Module, slice_size: List[int]
|
||||
):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
module.set_attention_slice(slice_size.pop())
|
||||
|
||||
@ -737,7 +845,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
elif channel_order == "bgr":
|
||||
brushnet_cond = torch.flip(brushnet_cond, dims=[1])
|
||||
else:
|
||||
raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
|
||||
raise ValueError(
|
||||
f"unknown `brushnet_conditioning_channel_order`: {channel_order}"
|
||||
)
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
@ -773,7 +883,9 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||
raise ValueError(
|
||||
"class_labels should be provided when num_class_embeds > 0"
|
||||
)
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
@ -812,7 +924,10 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
# 3. down
|
||||
down_block_res_samples = (sample,)
|
||||
for downsample_block in self.down_blocks:
|
||||
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||
if (
|
||||
hasattr(downsample_block, "has_cross_attention")
|
||||
and downsample_block.has_cross_attention
|
||||
):
|
||||
sample, res_samples = downsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
@ -827,13 +942,20 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# 4. PaintingNet down blocks
|
||||
brushnet_down_block_res_samples = ()
|
||||
for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
|
||||
for down_block_res_sample, brushnet_down_block in zip(
|
||||
down_block_res_samples, self.brushnet_down_blocks
|
||||
):
|
||||
down_block_res_sample = brushnet_down_block(down_block_res_sample)
|
||||
brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
|
||||
brushnet_down_block_res_samples = brushnet_down_block_res_samples + (
|
||||
down_block_res_sample,
|
||||
)
|
||||
|
||||
# 5. mid
|
||||
if self.mid_block is not None:
|
||||
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||
if (
|
||||
hasattr(self.mid_block, "has_cross_attention")
|
||||
and self.mid_block.has_cross_attention
|
||||
):
|
||||
sample = self.mid_block(
|
||||
sample,
|
||||
emb,
|
||||
@ -852,15 +974,20 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
is_final_block = i == len(self.up_blocks) - 1
|
||||
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
||||
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||
down_block_res_samples = down_block_res_samples[
|
||||
: -len(upsample_block.resnets)
|
||||
]
|
||||
|
||||
# if we have not reached the final block and need to forward the
|
||||
# upsample size, we do it here
|
||||
if not is_final_block:
|
||||
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||
|
||||
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
||||
if (
|
||||
hasattr(upsample_block, "has_cross_attention")
|
||||
and upsample_block.has_cross_attention
|
||||
):
|
||||
sample, up_res_samples = upsample_block(
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
@ -869,7 +996,7 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
upsample_size=upsample_size,
|
||||
attention_mask=attention_mask,
|
||||
return_res_samples=True
|
||||
return_res_samples=True,
|
||||
)
|
||||
else:
|
||||
sample, up_res_samples = upsample_block(
|
||||
@ -877,53 +1004,87 @@ class BrushNetModel(ModelMixin, ConfigMixin):
|
||||
temb=emb,
|
||||
res_hidden_states_tuple=res_samples,
|
||||
upsample_size=upsample_size,
|
||||
return_res_samples=True
|
||||
return_res_samples=True,
|
||||
)
|
||||
|
||||
up_block_res_samples += up_res_samples
|
||||
|
||||
# 8. BrushNet up blocks
|
||||
brushnet_up_block_res_samples = ()
|
||||
for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
|
||||
for up_block_res_sample, brushnet_up_block in zip(
|
||||
up_block_res_samples, self.brushnet_up_blocks
|
||||
):
|
||||
up_block_res_sample = brushnet_up_block(up_block_res_sample)
|
||||
brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
|
||||
brushnet_up_block_res_samples = brushnet_up_block_res_samples + (
|
||||
up_block_res_sample,
|
||||
)
|
||||
|
||||
# 6. scaling
|
||||
if guess_mode and not self.config.global_pool_conditions:
|
||||
scales = torch.logspace(-1, 0,
|
||||
len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
|
||||
device=sample.device) # 0.1 to 1.0
|
||||
scales = torch.logspace(
|
||||
-1,
|
||||
0,
|
||||
len(brushnet_down_block_res_samples)
|
||||
+ 1
|
||||
+ len(brushnet_up_block_res_samples),
|
||||
device=sample.device,
|
||||
) # 0.1 to 1.0
|
||||
scales = scales * conditioning_scale
|
||||
|
||||
brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples,
|
||||
scales[:len(
|
||||
brushnet_down_block_res_samples)])]
|
||||
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
|
||||
brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples,
|
||||
scales[
|
||||
len(brushnet_down_block_res_samples) + 1:])]
|
||||
brushnet_down_block_res_samples = [
|
||||
sample * scale
|
||||
for sample, scale in zip(
|
||||
brushnet_down_block_res_samples,
|
||||
scales[: len(brushnet_down_block_res_samples)],
|
||||
)
|
||||
]
|
||||
brushnet_mid_block_res_sample = (
|
||||
brushnet_mid_block_res_sample
|
||||
* scales[len(brushnet_down_block_res_samples)]
|
||||
)
|
||||
brushnet_up_block_res_samples = [
|
||||
sample * scale
|
||||
for sample, scale in zip(
|
||||
brushnet_up_block_res_samples,
|
||||
scales[len(brushnet_down_block_res_samples) + 1 :],
|
||||
)
|
||||
]
|
||||
else:
|
||||
brushnet_down_block_res_samples = [sample * conditioning_scale for sample in
|
||||
brushnet_down_block_res_samples]
|
||||
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
|
||||
brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
|
||||
brushnet_down_block_res_samples = [
|
||||
sample * conditioning_scale
|
||||
for sample in brushnet_down_block_res_samples
|
||||
]
|
||||
brushnet_mid_block_res_sample = (
|
||||
brushnet_mid_block_res_sample * conditioning_scale
|
||||
)
|
||||
brushnet_up_block_res_samples = [
|
||||
sample * conditioning_scale for sample in brushnet_up_block_res_samples
|
||||
]
|
||||
|
||||
if self.config.global_pool_conditions:
|
||||
brushnet_down_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True)
|
||||
for sample in brushnet_down_block_res_samples
|
||||
]
|
||||
brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||
brushnet_mid_block_res_sample = torch.mean(
|
||||
brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True
|
||||
)
|
||||
brushnet_up_block_res_samples = [
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
|
||||
torch.mean(sample, dim=(2, 3), keepdim=True)
|
||||
for sample in brushnet_up_block_res_samples
|
||||
]
|
||||
|
||||
if not return_dict:
|
||||
return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
|
||||
return (
|
||||
brushnet_down_block_res_samples,
|
||||
brushnet_mid_block_res_sample,
|
||||
brushnet_up_block_res_samples,
|
||||
)
|
||||
|
||||
return BrushNetOutput(
|
||||
down_block_res_samples=brushnet_down_block_res_samples,
|
||||
mid_block_res_sample=brushnet_mid_block_res_sample,
|
||||
up_block_res_samples=brushnet_up_block_res_samples
|
||||
up_block_res_samples=brushnet_up_block_res_samples,
|
||||
)
|
||||
|
||||
|
||||
|
@@ -5,11 +5,21 @@ import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionMixin
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers import StableDiffusionMixin, UNet2DConditionModel
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)

from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.loaders import (
FromSingleFileMixin,
IPAdapterMixin,
LoraLoaderMixin,
TextualInversionLoaderMixin,
)
from diffusers.models import AutoencoderKL, ImageProjection
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.schedulers import KarrasDiffusionSchedulers
@@ -21,13 +31,20 @@ from diffusers.utils import (
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
from diffusers.utils.torch_utils import (
is_compiled_module,
is_torch_version,
randn_tensor,
)
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.pipeline_output import (
StableDiffusionPipelineOutput,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)

from .BrushNet_CA import BrushNetModel
from .unet_2d_condition import UNet2DConditionModel

logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@ -112,7 +129,9 @@ def retrieve_timesteps(
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
accepts_timesteps = "timesteps" in set(
|
||||
inspect.signature(scheduler.set_timesteps).parameters.keys()
|
||||
)
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
@ -220,7 +239,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
image_encoder=image_encoder,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
|
||||
self.image_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
|
||||
)
|
||||
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
|
||||
@ -301,21 +322,25 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
)
|
||||
text_input_idsA = text_inputsA.input_ids
|
||||
text_input_idsB = text_inputsB.input_ids
|
||||
untruncated_ids = self.tokenizer(promptA, padding="longest", return_tensors="pt").input_ids
|
||||
untruncated_ids = self.tokenizer(
|
||||
promptA, padding="longest", return_tensors="pt"
|
||||
).input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_idsA.shape[-1] and not torch.equal(
|
||||
text_input_idsA, untruncated_ids
|
||||
):
|
||||
if untruncated_ids.shape[-1] >= text_input_idsA.shape[
|
||||
-1
|
||||
] and not torch.equal(text_input_idsA, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder_brushnet.config,
|
||||
"use_attention_mask") and self.text_encoder_brushnet.config.use_attention_mask:
|
||||
if (
|
||||
hasattr(self.text_encoder_brushnet.config, "use_attention_mask")
|
||||
and self.text_encoder_brushnet.config.use_attention_mask
|
||||
):
|
||||
attention_mask = text_inputsA.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
@ -350,7 +375,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
prompt_embeds = prompt_embeds.view(
|
||||
bs_embed * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
@ -379,8 +406,12 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
|
||||
# textual inversion: process multi-vector tokens if necessary
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
uncond_tokensA = self.maybe_convert_prompt(uncond_tokensA, self.tokenizer)
|
||||
uncond_tokensB = self.maybe_convert_prompt(uncond_tokensB, self.tokenizer)
|
||||
uncond_tokensA = self.maybe_convert_prompt(
|
||||
uncond_tokensA, self.tokenizer
|
||||
)
|
||||
uncond_tokensB = self.maybe_convert_prompt(
|
||||
uncond_tokensB, self.tokenizer
|
||||
)
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_inputA = self.tokenizer(
|
||||
@ -398,8 +429,10 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder_brushnet.config,
|
||||
"use_attention_mask") and self.text_encoder_brushnet.config.use_attention_mask:
|
||||
if (
|
||||
hasattr(self.text_encoder_brushnet.config, "use_attention_mask")
|
||||
and self.text_encoder_brushnet.config.use_attention_mask
|
||||
):
|
||||
attention_mask = uncond_inputA.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
@ -412,7 +445,10 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
uncond_inputB.input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embedsA[0] * (t_nag) + (1 - t_nag) * negative_prompt_embedsB[0]
|
||||
negative_prompt_embeds = (
|
||||
negative_prompt_embedsA[0] * (t_nag)
|
||||
+ (1 - t_nag) * negative_prompt_embedsB[0]
|
||||
)
|
||||
|
||||
# negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
@ -420,10 +456,16 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(
|
||||
dtype=prompt_embeds_dtype, device=device
|
||||
)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
||||
1, num_images_per_prompt, 1
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
@ -511,30 +553,39 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
# print(prompt, text_input_ids)
|
||||
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
|
||||
untruncated_ids = self.tokenizer(
|
||||
prompt, padding="longest", return_tensors="pt"
|
||||
).input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
|
||||
text_input_ids, untruncated_ids
|
||||
):
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[
|
||||
-1
|
||||
] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = self.tokenizer.batch_decode(
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
|
||||
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
if (
|
||||
hasattr(self.text_encoder.config, "use_attention_mask")
|
||||
and self.text_encoder.config.use_attention_mask
|
||||
):
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
if clip_skip is None:
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=attention_mask
|
||||
)
|
||||
prompt_embeds = prompt_embeds[0]
|
||||
else:
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
|
||||
text_input_ids.to(device),
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
# Access the `hidden_states` first, that contains a tuple of
|
||||
# all the hidden states from the encoder layers. Then index into
|
||||
@ -544,7 +595,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
# representations. The `last_hidden_states` that we typically use for
|
||||
# obtaining the final prompt representations passes through the LayerNorm
|
||||
# layer.
|
||||
prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
|
||||
prompt_embeds = self.text_encoder.text_model.final_layer_norm(
|
||||
prompt_embeds
|
||||
)
|
||||
|
||||
if self.text_encoder is not None:
|
||||
prompt_embeds_dtype = self.text_encoder.dtype
|
||||
@ -558,7 +611,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
prompt_embeds = prompt_embeds.view(
|
||||
bs_embed * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
@ -595,7 +650,10 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
)
|
||||
# print("neg: ", uncond_input.input_ids)
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
if (
|
||||
hasattr(self.text_encoder.config, "use_attention_mask")
|
||||
and self.text_encoder.config.use_attention_mask
|
||||
):
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
@ -610,10 +668,16 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(
|
||||
dtype=prompt_embeds_dtype, device=device
|
||||
)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
||||
1, num_images_per_prompt, 1
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
@ -624,7 +688,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
return prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
|
||||
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
|
||||
def encode_image(
|
||||
self, image, device, num_images_per_prompt, output_hidden_states=None
|
||||
):
|
||||
dtype = next(self.image_encoder.parameters()).dtype
|
||||
|
||||
if not isinstance(image, torch.Tensor):
|
||||
@ -632,14 +698,20 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if output_hidden_states:
|
||||
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
image_enc_hidden_states = self.image_encoder(
|
||||
image, output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
uncond_image_enc_hidden_states = self.image_encoder(
|
||||
torch.zeros_like(image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
|
||||
uncond_image_enc_hidden_states = (
|
||||
uncond_image_enc_hidden_states.repeat_interleave(
|
||||
num_images_per_prompt, dim=0
|
||||
)
|
||||
)
|
||||
return image_enc_hidden_states, uncond_image_enc_hidden_states
|
||||
else:
|
||||
image_embeds = self.image_encoder(image).image_embeds
|
||||
@ -650,13 +722,20 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
|
||||
def prepare_ip_adapter_image_embeds(
|
||||
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
|
||||
self,
|
||||
ip_adapter_image,
|
||||
ip_adapter_image_embeds,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
):
|
||||
if ip_adapter_image_embeds is None:
|
||||
if not isinstance(ip_adapter_image, list):
|
||||
ip_adapter_image = [ip_adapter_image]
|
||||
|
||||
if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
|
||||
if len(ip_adapter_image) != len(
|
||||
self.unet.encoder_hid_proj.image_projection_layers
|
||||
):
|
||||
raise ValueError(
|
||||
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
||||
)
|
||||
@ -669,13 +748,17 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
||||
single_ip_adapter_image, device, 1, output_hidden_state
|
||||
)
|
||||
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
||||
single_image_embeds = torch.stack(
|
||||
[single_image_embeds] * num_images_per_prompt, dim=0
|
||||
)
|
||||
single_negative_image_embeds = torch.stack(
|
||||
[single_negative_image_embeds] * num_images_per_prompt, dim=0
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
single_image_embeds = torch.cat(
|
||||
[single_negative_image_embeds, single_image_embeds]
|
||||
)
|
||||
single_image_embeds = single_image_embeds.to(device)
|
||||
|
||||
image_embeds.append(single_image_embeds)
|
||||
@ -684,17 +767,24 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
image_embeds = []
|
||||
for single_image_embeds in ip_adapter_image_embeds:
|
||||
if do_classifier_free_guidance:
|
||||
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
||||
single_negative_image_embeds, single_image_embeds = (
|
||||
single_image_embeds.chunk(2)
|
||||
)
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
num_images_per_prompt,
|
||||
*(repeat_dims * len(single_image_embeds.shape[1:])),
|
||||
)
|
||||
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
||||
num_images_per_prompt,
|
||||
*(repeat_dims * len(single_negative_image_embeds.shape[1:])),
|
||||
)
|
||||
single_image_embeds = torch.cat(
|
||||
[single_negative_image_embeds, single_image_embeds]
|
||||
)
|
||||
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
||||
else:
|
||||
single_image_embeds = single_image_embeds.repeat(
|
||||
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
||||
num_images_per_prompt,
|
||||
*(repeat_dims * len(single_image_embeds.shape[1:])),
|
||||
)
|
||||
image_embeds.append(single_image_embeds)
|
||||
|
||||
@ -706,10 +796,14 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
has_nsfw_concept = None
|
||||
else:
|
||||
if torch.is_tensor(image):
|
||||
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
|
||||
feature_extractor_input = self.image_processor.postprocess(
|
||||
image, output_type="pil"
|
||||
)
|
||||
else:
|
||||
feature_extractor_input = self.image_processor.numpy_to_pil(image)
|
||||
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
|
||||
safety_checker_input = self.feature_extractor(
|
||||
feature_extractor_input, return_tensors="pt"
|
||||
).to(device)
|
||||
image, has_nsfw_concept = self.safety_checker(
|
||||
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
|
||||
)
|
||||
@ -734,13 +828,17 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
accepts_eta = "eta" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
accepts_generator = "generator" in set(
|
||||
inspect.signature(self.scheduler.step).parameters.keys()
|
||||
)
|
||||
if accepts_generator:
|
||||
extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
@ -761,14 +859,17 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
control_guidance_end=1.0,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
|
||||
if callback_steps is not None and (
|
||||
not isinstance(callback_steps, int) or callback_steps <= 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
k in self._callback_tensor_inputs
|
||||
for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
@ -783,8 +884,12 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
elif prompt is not None and (
|
||||
not isinstance(prompt, str) and not isinstance(prompt, list)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
||||
)
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
@ -820,7 +925,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
and isinstance(self.brushnet._orig_mod, BrushNetModel)
|
||||
):
|
||||
if not isinstance(brushnet_conditioning_scale, float):
|
||||
raise TypeError("For single brushnet: `brushnet_conditioning_scale` must be type `float`.")
|
||||
raise TypeError(
|
||||
"For single brushnet: `brushnet_conditioning_scale` must be type `float`."
|
||||
)
|
||||
else:
|
||||
assert False
|
||||
|
||||
@ -841,9 +948,13 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
|
||||
)
|
||||
if start < 0.0:
|
||||
raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
|
||||
raise ValueError(
|
||||
f"control guidance start: {start} can't be smaller than 0."
|
||||
)
|
||||
if end > 1.0:
|
||||
raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
|
||||
raise ValueError(
|
||||
f"control guidance end: {end} can't be larger than 1.0."
|
||||
)
|
||||
|
||||
if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
|
||||
raise ValueError(
|
||||
@ -864,8 +975,12 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
image_is_pil = isinstance(image, PIL.Image.Image)
|
||||
image_is_tensor = isinstance(image, torch.Tensor)
|
||||
image_is_np = isinstance(image, np.ndarray)
|
||||
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
|
||||
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
|
||||
image_is_pil_list = isinstance(image, list) and isinstance(
|
||||
image[0], PIL.Image.Image
|
||||
)
|
||||
image_is_tensor_list = isinstance(image, list) and isinstance(
|
||||
image[0], torch.Tensor
|
||||
)
|
||||
image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
|
||||
|
||||
if (
|
||||
@ -883,8 +998,12 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
mask_is_pil = isinstance(mask, PIL.Image.Image)
|
||||
mask_is_tensor = isinstance(mask, torch.Tensor)
|
||||
mask_is_np = isinstance(mask, np.ndarray)
|
||||
mask_is_pil_list = isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image)
|
||||
mask_is_tensor_list = isinstance(mask, list) and isinstance(mask[0], torch.Tensor)
|
||||
mask_is_pil_list = isinstance(mask, list) and isinstance(
|
||||
mask[0], PIL.Image.Image
|
||||
)
|
||||
mask_is_tensor_list = isinstance(mask, list) and isinstance(
|
||||
mask[0], torch.Tensor
|
||||
)
|
||||
mask_is_np_list = isinstance(mask, list) and isinstance(mask[0], np.ndarray)
|
||||
|
||||
if (
|
||||
@ -928,7 +1047,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
|
||||
do_classifier_free_guidance=False,
|
||||
guess_mode=False,
|
||||
):
|
||||
image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
|
||||
image = self.image_processor.preprocess(image, height=height, width=width).to(
|
||||
dtype=torch.float32
|
||||
)
|
||||
image_batch_size = image.shape[0]
|
||||
|
||||
if image_batch_size == 1:
|
||||
@ -947,8 +1068,23 @@ class StableDiffusionPowerPaintBrushNetPipeline(
return image.to(device=device, dtype=dtype)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
shape = (
batch_size,
num_channels_latents,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch
@ -1186,14 +1322,26 @@ class StableDiffusionPowerPaintBrushNetPipeline(
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)

brushnet = self.brushnet._orig_mod if is_compiled_module(self.brushnet) else self.brushnet
brushnet = (
self.brushnet._orig_mod
if is_compiled_module(self.brushnet)
else self.brushnet
)

# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
if not isinstance(control_guidance_start, list) and isinstance(
control_guidance_end, list
):
control_guidance_start = len(control_guidance_end) * [
control_guidance_start
]
elif not isinstance(control_guidance_end, list) and isinstance(
control_guidance_start, list
):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
elif not isinstance(control_guidance_start, list) and not isinstance(
control_guidance_end, list
):
control_guidance_start, control_guidance_end = (
[control_guidance_start],
[control_guidance_end],
@ -1241,7 +1389,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(

# 3. Encode input prompt
text_encoder_lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
self.cross_attention_kwargs.get("scale", None)
if self.cross_attention_kwargs is not None
else None
)

prompt_embeds = self._encode_prompt(
@ -1310,7 +1460,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
assert False

# 5. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps
)
self._num_timesteps = len(timesteps)

# 6. Prepare latent variables
@ -1330,14 +1482,15 @@ class StableDiffusionPowerPaintBrushNetPipeline(
# mask_i = transforms.ToPILImage()(image[0:1,:,:,:].squeeze(0))
# mask_i.save('_mask.png')
# print(brushnet.dtype)
conditioning_latents = self.vae.encode(
image.to(device=device, dtype=brushnet.dtype)).latent_dist.sample() * self.vae.config.scaling_factor
conditioning_latents = (
self.vae.encode(
image.to(device=device, dtype=brushnet.dtype)
).latent_dist.sample()
* self.vae.config.scaling_factor
)
mask = torch.nn.functional.interpolate(
original_mask,
size=(
conditioning_latents.shape[-2],
conditioning_latents.shape[-1]
)
size=(conditioning_latents.shape[-2], conditioning_latents.shape[-1]),
)
conditioning_latents = torch.concat([conditioning_latents, mask], 1)
# image = self.vae.decode(conditioning_latents[:1,:4,:,:] / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
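The conditioning built above is what BrushNet actually sees: the masked image is encoded by the VAE into 4 latent channels, the mask is resized to the latent resolution, and the two are concatenated into a 5-channel conditioning latent. A minimal shape-level sketch of the same idea, assuming a standard SD1.5 VAE (8x downsampling, scaling factor 0.18215); the helper name is illustrative, not part of this pipeline:

import torch

def build_brushnet_conditioning(vae, masked_image, mask, scaling_factor=0.18215):
    # masked_image: (B, 3, H, W) in [-1, 1]; mask: (B, 1, H, W) with 1 = region to inpaint
    latents = vae.encode(masked_image).latent_dist.sample() * scaling_factor   # (B, 4, H/8, W/8)
    mask_lat = torch.nn.functional.interpolate(mask, size=latents.shape[-2:])  # (B, 1, H/8, W/8)
    return torch.cat([latents, mask_lat], dim=1)                               # (B, 5, H/8, W/8)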
@ -1348,7 +1501,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
# 6.5 Optionally get Guidance Scale Embedding
timestep_cond = None
if self.unet.config.time_cond_proj_dim is not None:
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(
batch_size * num_images_per_prompt
)
timestep_cond = self.get_guidance_scale_embedding(
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
).to(device=device, dtype=latents.dtype)
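timestep_cond is only prepared when the UNet was trained with a guidance-embedding input (config.time_cond_proj_dim is set, as in LCM-distilled checkpoints): the guidance scale minus one is mapped to a sinusoidal embedding and fed to the UNet alongside the timestep. A rough sketch of such an embedding, assuming the usual sin/cos construction rather than quoting the pipeline's own helper:

import math
import torch

def guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 256) -> torch.Tensor:
    # w: (B,) guidance scales (already shifted by -1 upstream)
    w = w.float() * 1000.0
    half = embedding_dim // 2
    freqs = torch.exp(torch.arange(half) * -(math.log(10000.0) / (half - 1)))
    emb = w[:, None] * freqs[None, :]
    return torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)  # (B, embedding_dim)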
@ -1370,7 +1525,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
brushnet_keep.append(keeps[0] if isinstance(brushnet, BrushNetModel) else keeps)
brushnet_keep.append(
keeps[0] if isinstance(brushnet, BrushNetModel) else keeps
)
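brushnet_keep plays the same role as the ControlNet keep list: for each denoising step it is 1.0 while the normalized step index falls inside the [control_guidance_start, control_guidance_end] window and 0.0 outside it, which later zeroes the effective conditioning scale. A small worked example (values chosen for illustration):

# With 20 steps, start=0.0 and end=0.5, BrushNet only guides the first half of denoising.
num_steps, start, end = 20, 0.0, 0.5
keeps = [
    1.0 - float(i / num_steps < start or (i + 1) / num_steps > end)
    for i in range(num_steps)
]
assert keeps == [1.0] * 10 + [0.0] * 10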

# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@ -1381,31 +1538,45 @@ class StableDiffusionPowerPaintBrushNetPipeline(
for i, t in enumerate(timesteps):
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
if (is_unet_compiled and is_brushnet_compiled) and is_torch_higher_equal_2_1:
if (
is_unet_compiled and is_brushnet_compiled
) and is_torch_higher_equal_2_1:
torch._inductor.cudagraph_mark_step_begin()
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
latent_model_input = (
torch.cat([latents] * 2)
if self.do_classifier_free_guidance
else latents
)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)

# brushnet(s) inference
if guess_mode and self.do_classifier_free_guidance:
# Infer BrushNet only for the conditional batch.
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
control_model_input = self.scheduler.scale_model_input(
control_model_input, t
)
brushnet_prompt_embeds = prompt_embeds.chunk(2)[1]
else:
control_model_input = latent_model_input
brushnet_prompt_embeds = prompt_embeds

if isinstance(brushnet_keep[i], list):
cond_scale = [c * s for c, s in zip(brushnet_conditioning_scale, brushnet_keep[i])]
cond_scale = [
c * s
for c, s in zip(brushnet_conditioning_scale, brushnet_keep[i])
]
else:
brushnet_cond_scale = brushnet_conditioning_scale
if isinstance(brushnet_cond_scale, list):
brushnet_cond_scale = brushnet_cond_scale[0]
cond_scale = brushnet_cond_scale * brushnet_keep[i]

down_block_res_samples, mid_block_res_sample, up_block_res_samples = self.brushnet(
down_block_res_samples, mid_block_res_sample, up_block_res_samples = (
self.brushnet(
control_model_input,
t,
encoder_hidden_states=brushnet_prompt_embeds,
@ -1414,14 +1585,23 @@ class StableDiffusionPowerPaintBrushNetPipeline(
guess_mode=guess_mode,
return_dict=False,
)
)

if guess_mode and self.do_classifier_free_guidance:
# Inferred BrushNet only for the conditional batch.
# To apply the output of BrushNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
up_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in up_block_res_samples]
down_block_res_samples = [
torch.cat([torch.zeros_like(d), d])
for d in down_block_res_samples
]
mid_block_res_sample = torch.cat(
[torch.zeros_like(mid_block_res_sample), mid_block_res_sample]
)
up_block_res_samples = [
torch.cat([torch.zeros_like(d), d])
for d in up_block_res_samples
]
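In guess mode with classifier-free guidance enabled, BrushNet was run on the conditional half of the batch only, so each residual is padded with zeros for the unconditional half before being handed to the UNet; the unconditional branch is then denoised as if BrushNet were not attached. A toy illustration of that padding (shapes are arbitrary):

import torch

d = torch.randn(1, 320, 64, 64)               # residual computed for the conditional batch only
padded = torch.cat([torch.zeros_like(d), d])  # (2, 320, 64, 64): zeros for uncond, residual for cond
assert padded.shape[0] == 2 and padded[0].abs().sum() == 0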

# predict the noise residual
noise_pred = self.unet(
@ -1440,10 +1620,14 @@ class StableDiffusionPowerPaintBrushNetPipeline(
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = noise_pred_uncond + self.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
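This is the standard classifier-free guidance update: the final noise estimate starts from the unconditional prediction and extrapolates toward the text-conditioned one by guidance_scale. A quick worked example with guidance_scale = 7.5 and toy tensors:

import torch

guidance_scale = 7.5
noise_pred_uncond = torch.zeros(1, 4, 64, 64)
noise_pred_text = torch.ones(1, 4, 64, 64)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
assert noise_pred.mean().item() == 7.5  # every element is pushed 7.5x along the text direction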

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
)[0]

if callback_on_step_end is not None:
callback_kwargs = {}
@ -1453,10 +1637,14 @@ class StableDiffusionPowerPaintBrushNetPipeline(

latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
negative_prompt_embeds = callback_outputs.pop(
"negative_prompt_embeds", negative_prompt_embeds
)

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
@ -1470,10 +1658,14 @@ class StableDiffusionPowerPaintBrushNetPipeline(
torch.cuda.empty_cache()

if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
0
]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
image = self.vae.decode(
latents / self.vae.config.scaling_factor,
return_dict=False,
generator=generator,
)[0]
image, has_nsfw_concept = self.run_safety_checker(
image, device, prompt_embeds.dtype
)
else:
image = latents
has_nsfw_concept = None
@ -1483,7 +1675,9 @@ class StableDiffusionPowerPaintBrushNetPipeline(
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
image = self.image_processor.postprocess(
image, output_type=output_type, do_denormalize=do_denormalize
)

# Offload all models
self.maybe_free_model_hooks()
@ -1491,4 +1685,6 @@ class StableDiffusionPowerPaintBrushNetPipeline(
if not return_dict:
return (image, has_nsfw_concept)

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
return StableDiffusionPipelineOutput(
images=image, nsfw_content_detected=has_nsfw_concept
)

File diff suppressed because it is too large
File diff suppressed because it is too large
@ -122,9 +122,13 @@ class ModelInfo(BaseModel):
@computed_field
@property
def support_powerpaint_v2(self) -> bool:
return self.model_type in [
return (
self.model_type
in [
ModelType.DIFFUSERS_SD,
]
and self.name != POWERPAINT_NAME
)


class Choices(str, Enum):
@ -215,7 +219,6 @@ class SDSampler(str, Enum):
lcm = "LCM"



class PowerPaintTask(Choices):
text_guided = "text-guided"
context_aware = "context-aware"

@ -59,6 +59,9 @@ const DiffusionOptions = () => {
updateExtenderDirection,
adjustMask,
clearMask,
updateEnablePowerPaintV2,
updateEnableBrushNet,
updateEnableControlnet,
] = useStore((state) => [
state.serverConfig.samplers,
state.settings,
@ -71,6 +74,9 @@ const DiffusionOptions = () => {
state.updateExtenderDirection,
state.adjustMask,
state.clearMask,
state.updateEnablePowerPaintV2,
state.updateEnableBrushNet,
state.updateEnableControlnet,
])
const [exampleImage, isExampleImageLoaded] = useImage(paintByExampleFile)
const negativePromptRef = useRef(null)
@ -114,12 +120,8 @@ const DiffusionOptions = () => {
return null
}

let disable = settings.enableControlnet
let toolTip =
"BrushNet is a plug-and-play image inpainting model with decomposed dual-branch diffusion. It can be used to inpaint images by conditioning on a mask."
if (disable) {
toolTip = "ControlNet is enabled, BrushNet is disabled."
}
"BrushNet is a plug-and-play image inpainting model that works on any SD1.5 base model."

return (
<div className="flex flex-col gap-4">
@ -129,20 +131,19 @@ const DiffusionOptions = () => {
text="BrushNet"
url="https://github.com/TencentARC/BrushNet"
toolTip={toolTip}
disabled={disable}
/>
<Switch
id="brushnet"
checked={settings.enableBrushNet}
onCheckedChange={(value) => {
updateSettings({ enableBrushNet: value })
updateEnableBrushNet(value)
}}
disabled={disable}
/>
</RowContainer>
<RowContainer>
{/* <RowContainer>
<Slider
defaultValue={[100]}
className="w-[180px]"
min={1}
max={100}
step={1}
@ -155,14 +156,13 @@ const DiffusionOptions = () => {
<NumberInput
id="brushnet-weight"
className="w-[60px] rounded-full"
disabled={!settings.enableBrushNet || disable}
numberValue={settings.brushnetConditioningScale}
allowFloat={false}
onNumberValueChange={(val) => {
updateSettings({ brushnetConditioningScale: val })
}}
/>
</RowContainer>
</RowContainer> */}

<RowContainer>
<Select
@ -198,12 +198,8 @@ const DiffusionOptions = () => {
return null
}

let disable = settings.enableBrushNet
let toolTip =
"Using an additional conditioning image to control how an image is generated"
if (disable) {
toolTip = "BrushNet is enabled, ControlNet is disabled."
}

return (
<div className="flex flex-col gap-4">
@ -213,15 +209,13 @@ const DiffusionOptions = () => {
text="ControlNet"
url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inpaint#controlnet"
toolTip={toolTip}
disabled={disable}
/>
<Switch
id="controlnet"
checked={settings.enableControlnet}
onCheckedChange={(value) => {
updateSettings({ enableControlnet: value })
updateEnableControlnet(value)
}}
disabled={disable}
/>
</RowContainer>

@ -233,7 +227,7 @@ const DiffusionOptions = () => {
min={1}
max={100}
step={1}
disabled={!settings.enableControlnet || disable}
disabled={!settings.enableControlnet}
value={[Math.floor(settings.controlnetConditioningScale * 100)]}
onValueChange={(vals) =>
updateSettings({ controlnetConditioningScale: vals[0] / 100 })
@ -242,7 +236,7 @@ const DiffusionOptions = () => {
<NumberInput
id="controlnet-weight"
className="w-[60px] rounded-full"
disabled={!settings.enableControlnet || disable}
disabled={!settings.enableControlnet}
numberValue={settings.controlnetConditioningScale}
allowFloat={false}
onNumberValueChange={(val) => {
@ -286,12 +280,8 @@ const DiffusionOptions = () => {
return null
}

let disable = settings.enableBrushNet
let toolTip =
"Enable quality image generation in typically 2-4 steps. Suggest disabling guidance_scale by setting it to 0. You can also try values between 1.0 and 2.0. When LCM Lora is enabled, LCMSampler will be used automatically."
if (disable) {
toolTip = "BrushNet is enabled, LCM Lora is disabled."
}
"Enable quality image generation in typically 2-8 steps. Suggest disabling guidance_scale by setting it to 0. You can also try values between 1.0 and 2.0. When LCM Lora is enabled, LCMSampler will be used automatically."

return (
<>
@ -300,7 +290,6 @@ const DiffusionOptions = () => {
text="LCM LoRA"
url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inference_with_lcm_lora"
toolTip={toolTip}
disabled={disable}
/>
<Switch
id="lcm-lora"
@ -308,7 +297,6 @@ const DiffusionOptions = () => {
onCheckedChange={(value) => {
updateSettings({ enableLCMLora: value })
}}
disabled={disable}
/>
</RowContainer>
<Separator />
@ -561,10 +549,6 @@ const DiffusionOptions = () => {
}

const renderPowerPaintTaskType = () => {
if (settings.model.name !== POWERPAINT) {
return null
}

return (
<RowContainer>
<LabelTitle
@ -579,7 +563,7 @@ const DiffusionOptions = () => {
}}
disabled={settings.showExtender}
>
<SelectTrigger className="w-[140px]">
<SelectTrigger className="w-[130px]">
<SelectValue placeholder="Select task" />
</SelectTrigger>
<SelectContent align="end">
@ -587,6 +571,7 @@ const DiffusionOptions = () => {
{[
PowerPaintTask.text_guided,
PowerPaintTask.object_remove,
PowerPaintTask.context_aware,
PowerPaintTask.shape_guided,
].map((task) => (
<SelectItem key={task} value={task}>
@ -600,6 +585,44 @@ const DiffusionOptions = () => {
)
}

const renderPowerPaintV1 = () => {
if (settings.model.name !== POWERPAINT) {
return null
}
return (
<>
{renderPowerPaintTaskType()}
<Separator />
</>
)
}

const renderPowerPaintV2 = () => {
if (settings.model.support_powerpaint_v2 === false) {
return null
}

return (
<>
<RowContainer>
<LabelTitle
text="PowerPaint V2"
toolTip="PowerPaint is a plug-and-play image inpainting model that works on any SD1.5 base model."
/>
<Switch
id="powerpaint-v2"
checked={settings.enablePowerPaintV2}
onCheckedChange={(value) => {
updateEnablePowerPaintV2(value)
}}
/>
</RowContainer>
{renderPowerPaintTaskType()}
<Separator />
</>
)
}

const renderSteps = () => {
return (
<RowContainer>
@ -868,7 +891,7 @@ const DiffusionOptions = () => {
{renderMaskBlur()}
{renderMaskAdjuster()}
{renderMatchHistograms()}
{renderPowerPaintTaskType()}
{renderPowerPaintV1()}
{renderSteps()}
{renderGuidanceScale()}
{renderP2PImageGuidanceScale()}
@ -878,6 +901,7 @@ const DiffusionOptions = () => {
{renderNegativePrompt()}
<Separator />
{renderBrushNetSetting()}
{renderPowerPaintV2()}
{renderConterNetSetting()}
{renderLCMLora()}
{renderPaintByExample()}

@ -22,7 +22,7 @@ const SelectTrigger = React.forwardRef<
<SelectPrimitive.Trigger
ref={ref}
className={cn(
"flex h-9 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent px-3 py-2 text-sm shadow-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-1 focus:ring-ring disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1",
"flex h-9 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent pl-2 pr-1 py-2 text-sm shadow-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-1 focus:ring-ring disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1",
className
)}
tabIndex={-1}

@ -79,6 +79,7 @@ export default async function inpaint(
enable_brushnet: settings.enableBrushNet,
brushnet_method: settings.brushnetMethod ? settings.brushnetMethod : "",
brushnet_conditioning_scale: settings.brushnetConditioningScale,
enable_powerpaint_v2: settings.enablePowerPaintV2,
powerpaint_task: settings.showExtender
? PowerPaintTask.outpainting
: settings.powerpaintTask,

@ -106,6 +106,7 @@ export type Settings = {
enableLCMLora: boolean

// PowerPaint
enablePowerPaintV2: boolean
powerpaintTask: PowerPaintTask

// AdjustMask
@ -194,6 +195,12 @@ type AppAction = {
setServerConfig: (newValue: ServerConfig) => void
setSeed: (newValue: number) => void
updateSettings: (newSettings: Partial<Settings>) => void

// mutually exclusive
updateEnablePowerPaintV2: (newValue: boolean) => void
updateEnableBrushNet: (newValue: boolean) => void
updateEnableControlnet: (newValue: boolean) => void

setModel: (newModel: ModelInfo) => void
updateFileManagerState: (newState: Partial<FileManagerState>) => void
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
@ -311,6 +318,7 @@ const defaultValues: AppState = {
support_brushnet: false,
support_strength: false,
support_outpainting: false,
support_powerpaint_v2: false,
controlnets: [],
brushnets: [],
support_lcm_lora: false,
@ -425,6 +433,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
if (
get().settings.model.support_outpainting &&
settings.showExtender &&
extenderState.x === 0 &&
extenderState.y === 0 &&
extenderState.height === imageHeight &&
extenderState.width === imageWidth
) {
@ -798,6 +808,38 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
})
},

updateEnablePowerPaintV2: (newValue: boolean) => {
get().updateSettings({ enablePowerPaintV2: newValue })
if (newValue) {
get().updateSettings({
enableBrushNet: false,
enableControlnet: false,
enableLCMLora: false,
})
}
},

updateEnableBrushNet: (newValue: boolean) => {
get().updateSettings({ enableBrushNet: newValue })
if (newValue) {
get().updateSettings({
enablePowerPaintV2: false,
enableControlnet: false,
enableLCMLora: false,
})
}
},

updateEnableControlnet(newValue) {
get().updateSettings({ enableControlnet: newValue })
if (newValue) {
get().updateSettings({
enablePowerPaintV2: false,
enableBrushNet: false,
})
}
},

setModel: (newModel: ModelInfo) => {
set((state) => {
state.settings.model = newModel

@ -49,6 +49,7 @@ export interface ModelInfo {
support_outpainting: boolean
support_controlnet: boolean
support_brushnet: boolean
support_powerpaint_v2: boolean
controlnets: string[]
brushnets: string[]
support_lcm_lora: boolean
@ -123,6 +124,7 @@ export enum ExtenderDirection {
export enum PowerPaintTask {
text_guided = "text-guided",
shape_guided = "shape-guided",
context_aware = "context-aware",
object_remove = "object-remove",
outpainting = "outpainting",
}