Merge branch 'brushnet'
This commit is contained in:
commit
c516a23fd8
@ -3,7 +3,7 @@ from pathlib import Path
|
|||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import psutil
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
@ -35,6 +35,7 @@ def glob_images(path: Path) -> Dict[str, Path]:
|
|||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def batch_inpaint(
|
def batch_inpaint(
|
||||||
model: str,
|
model: str,
|
||||||
device,
|
device,
|
||||||
@ -46,7 +47,7 @@ def batch_inpaint(
|
|||||||
):
|
):
|
||||||
if image.is_dir() and output.is_file():
|
if image.is_dir() and output.is_file():
|
||||||
logger.error(
|
logger.error(
|
||||||
f"invalid --output: when image is a directory, output should be a directory"
|
"invalid --output: when image is a directory, output should be a directory"
|
||||||
)
|
)
|
||||||
exit(-1)
|
exit(-1)
|
||||||
output.mkdir(parents=True, exist_ok=True)
|
output.mkdir(parents=True, exist_ok=True)
|
||||||
@ -54,10 +55,10 @@ def batch_inpaint(
|
|||||||
image_paths = glob_images(image)
|
image_paths = glob_images(image)
|
||||||
mask_paths = glob_images(mask)
|
mask_paths = glob_images(mask)
|
||||||
if len(image_paths) == 0:
|
if len(image_paths) == 0:
|
||||||
logger.error(f"invalid --image: empty image folder")
|
logger.error("invalid --image: empty image folder")
|
||||||
exit(-1)
|
exit(-1)
|
||||||
if len(mask_paths) == 0:
|
if len(mask_paths) == 0:
|
||||||
logger.error(f"invalid --mask: empty mask folder")
|
logger.error("invalid --mask: empty mask folder")
|
||||||
exit(-1)
|
exit(-1)
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
@ -92,9 +93,9 @@ def batch_inpaint(
|
|||||||
|
|
||||||
infos = Image.open(image_p).info
|
infos = Image.open(image_p).info
|
||||||
|
|
||||||
img = cv2.imread(str(image_p))
|
img = np.array(Image.open(image_p).convert("RGB"))
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
|
mask_img = np.array(Image.open(mask_p).convert("L"))
|
||||||
mask_img = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
|
|
||||||
if mask_img.shape[:2] != img.shape[:2]:
|
if mask_img.shape[:2] != img.shape[:2]:
|
||||||
progress.log(
|
progress.log(
|
||||||
f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}"
|
f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}"
|
||||||
|
@ -6,7 +6,6 @@ KANDINSKY22_NAME = "kandinsky-community/kandinsky-2-2-decoder-inpaint"
|
|||||||
POWERPAINT_NAME = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
|
POWERPAINT_NAME = "Sanster/PowerPaint-V1-stable-diffusion-inpainting"
|
||||||
ANYTEXT_NAME = "Sanster/AnyText"
|
ANYTEXT_NAME = "Sanster/AnyText"
|
||||||
|
|
||||||
|
|
||||||
DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
|
DIFFUSERS_SD_CLASS_NAME = "StableDiffusionPipeline"
|
||||||
DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
|
DIFFUSERS_SD_INPAINT_CLASS_NAME = "StableDiffusionInpaintPipeline"
|
||||||
DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
|
DIFFUSERS_SDXL_CLASS_NAME = "StableDiffusionXLPipeline"
|
||||||
@ -62,6 +61,11 @@ SD_CONTROLNET_CHOICES: List[str] = [
|
|||||||
"lllyasviel/control_v11f1p_sd15_depth",
|
"lllyasviel/control_v11f1p_sd15_depth",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
SD_BRUSHNET_CHOICES: List[str] = [
|
||||||
|
"Sanster/brushnet_random_mask",
|
||||||
|
"Sanster/brushnet_segmentation_mask"
|
||||||
|
]
|
||||||
|
|
||||||
SD2_CONTROLNET_CHOICES = [
|
SD2_CONTROLNET_CHOICES = [
|
||||||
"thibaud/controlnet-sd21-canny-diffusers",
|
"thibaud/controlnet-sd21-canny-diffusers",
|
||||||
"thibaud/controlnet-sd21-depth-diffusers",
|
"thibaud/controlnet-sd21-depth-diffusers",
|
||||||
|
@ -1,3 +1,4 @@
|
|||||||
|
import glob
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
@ -67,7 +68,7 @@ def get_sd_model_type(model_abs_path: str) -> ModelType:
|
|||||||
if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e):
|
if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e):
|
||||||
model_type = ModelType.DIFFUSERS_SD
|
model_type = ModelType.DIFFUSERS_SD
|
||||||
else:
|
else:
|
||||||
raise e
|
logger.info(f"Ignore non sd or sdxl file: {model_abs_path}")
|
||||||
return model_type
|
return model_type
|
||||||
|
|
||||||
|
|
||||||
@ -92,10 +93,10 @@ def get_sdxl_model_type(model_abs_path: str) -> ModelType:
|
|||||||
else:
|
else:
|
||||||
model_type = ModelType.DIFFUSERS_SDXL
|
model_type = ModelType.DIFFUSERS_SDXL
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e):
|
if "but got torch.Size([320, 4, 3, 3])" in str(e):
|
||||||
model_type = ModelType.DIFFUSERS_SDXL
|
model_type = ModelType.DIFFUSERS_SDXL
|
||||||
else:
|
else:
|
||||||
raise e
|
logger.info(f"Ignore non sd or sdxl file: {model_abs_path}")
|
||||||
return model_type
|
return model_type
|
||||||
|
|
||||||
|
|
||||||
@ -192,7 +193,9 @@ def scan_diffusers_models() -> List[ModelInfo]:
|
|||||||
cache_dir = Path(HF_HUB_CACHE)
|
cache_dir = Path(HF_HUB_CACHE)
|
||||||
# logger.info(f"Scanning diffusers models in {cache_dir}")
|
# logger.info(f"Scanning diffusers models in {cache_dir}")
|
||||||
diffusers_model_names = []
|
diffusers_model_names = []
|
||||||
for it in cache_dir.glob("**/*/model_index.json"):
|
model_index_files = glob.glob(os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True)
|
||||||
|
for it in model_index_files:
|
||||||
|
it = Path(it)
|
||||||
with open(it, "r", encoding="utf-8") as f:
|
with open(it, "r", encoding="utf-8") as f:
|
||||||
try:
|
try:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
@ -238,7 +241,9 @@ def _scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
|
|||||||
cache_dir = Path(cache_dir)
|
cache_dir = Path(cache_dir)
|
||||||
available_models = []
|
available_models = []
|
||||||
diffusers_model_names = []
|
diffusers_model_names = []
|
||||||
for it in cache_dir.glob("**/*/model_index.json"):
|
model_index_files = glob.glob(os.path.join(cache_dir, "**/*", "model_index.json"), recursive=True)
|
||||||
|
for it in model_index_files:
|
||||||
|
it = Path(it)
|
||||||
with open(it, "r", encoding="utf-8") as f:
|
with open(it, "r", encoding="utf-8") as f:
|
||||||
try:
|
try:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
|
@ -13,7 +13,7 @@ from iopaint.helper import (
|
|||||||
switch_mps_device,
|
switch_mps_device,
|
||||||
)
|
)
|
||||||
from iopaint.schema import InpaintRequest, HDStrategy, SDSampler
|
from iopaint.schema import InpaintRequest, HDStrategy, SDSampler
|
||||||
from .helper.g_diffuser_bot import expand_image, expand_image2
|
from .helper.g_diffuser_bot import expand_image
|
||||||
from .utils import get_scheduler
|
from .utils import get_scheduler
|
||||||
|
|
||||||
|
|
||||||
@ -35,8 +35,7 @@ class InpaintModel:
|
|||||||
self.init_model(device, **kwargs)
|
self.init_model(device, **kwargs)
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def init_model(self, device, **kwargs):
|
def init_model(self, device, **kwargs): ...
|
||||||
...
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@ -53,8 +52,7 @@ class InpaintModel:
|
|||||||
...
|
...
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def download():
|
def download(): ...
|
||||||
...
|
|
||||||
|
|
||||||
def _pad_forward(self, image, mask, config: InpaintRequest):
|
def _pad_forward(self, image, mask, config: InpaintRequest):
|
||||||
origin_height, origin_width = image.shape[:2]
|
origin_height, origin_width = image.shape[:2]
|
||||||
@ -96,7 +94,7 @@ class InpaintModel:
|
|||||||
# logger.info(f"hd_strategy: {config.hd_strategy}")
|
# logger.info(f"hd_strategy: {config.hd_strategy}")
|
||||||
if config.hd_strategy == HDStrategy.CROP:
|
if config.hd_strategy == HDStrategy.CROP:
|
||||||
if max(image.shape) > config.hd_strategy_crop_trigger_size:
|
if max(image.shape) > config.hd_strategy_crop_trigger_size:
|
||||||
logger.info(f"Run crop strategy")
|
logger.info("Run crop strategy")
|
||||||
boxes = boxes_from_mask(mask)
|
boxes = boxes_from_mask(mask)
|
||||||
crop_result = []
|
crop_result = []
|
||||||
for box in boxes:
|
for box in boxes:
|
||||||
@ -327,14 +325,12 @@ class DiffusionInpaintModel(InpaintModel):
|
|||||||
padding_r = max(0, cropper_r - image_r)
|
padding_r = max(0, cropper_r - image_r)
|
||||||
padding_b = max(0, cropper_b - image_b)
|
padding_b = max(0, cropper_b - image_b)
|
||||||
|
|
||||||
expanded_image, mask_image = expand_image2(
|
expanded_image, mask_image = expand_image(
|
||||||
cropped_image,
|
cropped_image,
|
||||||
left=padding_l,
|
left=padding_l,
|
||||||
top=padding_t,
|
top=padding_t,
|
||||||
right=padding_r,
|
right=padding_r,
|
||||||
bottom=padding_b,
|
bottom=padding_b,
|
||||||
softness=config.sd_outpainting_softness,
|
|
||||||
space=config.sd_outpainting_space,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 最终扩大了的 image, BGR
|
# 最终扩大了的 image, BGR
|
||||||
@ -381,15 +377,6 @@ class DiffusionInpaintModel(InpaintModel):
|
|||||||
interpolation=cv2.INTER_CUBIC,
|
interpolation=cv2.INTER_CUBIC,
|
||||||
)
|
)
|
||||||
|
|
||||||
# blend result, copy from g_diffuser_bot
|
|
||||||
# mask_rgb = 1.0 - np_img_grey_to_rgb(mask / 255.0)
|
|
||||||
# inpaint_result = np.clip(
|
|
||||||
# inpaint_result * (1.0 - mask_rgb) + image * mask_rgb, 0.0, 255.0
|
|
||||||
# )
|
|
||||||
# original_pixel_indices = mask < 127
|
|
||||||
# inpaint_result[original_pixel_indices] = image[:, :, ::-1][
|
|
||||||
# original_pixel_indices
|
|
||||||
# ]
|
|
||||||
return inpaint_result
|
return inpaint_result
|
||||||
|
|
||||||
def set_scheduler(self, config: InpaintRequest):
|
def set_scheduler(self, config: InpaintRequest):
|
||||||
@ -412,7 +399,7 @@ class DiffusionInpaintModel(InpaintModel):
|
|||||||
if config.sd_match_histograms:
|
if config.sd_match_histograms:
|
||||||
result = self._match_histograms(result, image[:, :, ::-1], mask)
|
result = self._match_histograms(result, image[:, :, ::-1], mask)
|
||||||
|
|
||||||
# if config.sd_mask_blur != 0:
|
if config.use_extender and config.sd_mask_blur != 0:
|
||||||
# k = 2 * config.sd_mask_blur + 1
|
k = 2 * config.sd_mask_blur + 1
|
||||||
# mask = cv2.GaussianBlur(mask, (k, k), 0)
|
mask = cv2.GaussianBlur(mask, (k, k), 0)
|
||||||
return result, image, mask
|
return result, image, mask
|
||||||
|
931
iopaint/model/brushnet/brushnet.py
Normal file
931
iopaint/model/brushnet/brushnet.py
Normal file
@ -0,0 +1,931 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
||||||
|
from diffusers.utils import BaseOutput, logging
|
||||||
|
from diffusers.models.attention_processor import (
|
||||||
|
ADDED_KV_ATTENTION_PROCESSORS,
|
||||||
|
CROSS_ATTENTION_PROCESSORS,
|
||||||
|
AttentionProcessor,
|
||||||
|
AttnAddedKVProcessor,
|
||||||
|
AttnProcessor,
|
||||||
|
)
|
||||||
|
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, \
|
||||||
|
TimestepEmbedding, Timesteps
|
||||||
|
from diffusers.models.modeling_utils import ModelMixin
|
||||||
|
from diffusers.models.unets.unet_2d_blocks import (
|
||||||
|
CrossAttnDownBlock2D,
|
||||||
|
DownBlock2D, get_down_block, get_up_block,
|
||||||
|
)
|
||||||
|
|
||||||
|
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
|
||||||
|
from .unet_2d_blocks import MidBlock2D
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BrushNetOutput(BaseOutput):
|
||||||
|
"""
|
||||||
|
The output of [`BrushNetModel`].
|
||||||
|
|
||||||
|
Args:
|
||||||
|
up_block_res_samples (`tuple[torch.Tensor]`):
|
||||||
|
A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
|
||||||
|
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
||||||
|
used to condition the original UNet's upsampling activations.
|
||||||
|
down_block_res_samples (`tuple[torch.Tensor]`):
|
||||||
|
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
|
||||||
|
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
|
||||||
|
used to condition the original UNet's downsampling activations.
|
||||||
|
mid_down_block_re_sample (`torch.Tensor`):
|
||||||
|
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
|
||||||
|
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
|
||||||
|
Output can be used to condition the original UNet's middle block activation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
up_block_res_samples: Tuple[torch.Tensor]
|
||||||
|
down_block_res_samples: Tuple[torch.Tensor]
|
||||||
|
mid_block_res_sample: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
|
class BrushNetModel(ModelMixin, ConfigMixin):
|
||||||
|
"""
|
||||||
|
A BrushNet model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
in_channels (`int`, defaults to 4):
|
||||||
|
The number of channels in the input sample.
|
||||||
|
flip_sin_to_cos (`bool`, defaults to `True`):
|
||||||
|
Whether to flip the sin to cos in the time embedding.
|
||||||
|
freq_shift (`int`, defaults to 0):
|
||||||
|
The frequency shift to apply to the time embedding.
|
||||||
|
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||||
|
The tuple of downsample blocks to use.
|
||||||
|
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
||||||
|
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
|
||||||
|
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
|
||||||
|
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
|
||||||
|
The tuple of upsample blocks to use.
|
||||||
|
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
|
||||||
|
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
|
||||||
|
The tuple of output channels for each block.
|
||||||
|
layers_per_block (`int`, defaults to 2):
|
||||||
|
The number of layers per block.
|
||||||
|
downsample_padding (`int`, defaults to 1):
|
||||||
|
The padding to use for the downsampling convolution.
|
||||||
|
mid_block_scale_factor (`float`, defaults to 1):
|
||||||
|
The scale factor to use for the mid block.
|
||||||
|
act_fn (`str`, defaults to "silu"):
|
||||||
|
The activation function to use.
|
||||||
|
norm_num_groups (`int`, *optional*, defaults to 32):
|
||||||
|
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
|
||||||
|
in post-processing.
|
||||||
|
norm_eps (`float`, defaults to 1e-5):
|
||||||
|
The epsilon to use for the normalization.
|
||||||
|
cross_attention_dim (`int`, defaults to 1280):
|
||||||
|
The dimension of the cross attention features.
|
||||||
|
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
|
||||||
|
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
|
||||||
|
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
|
||||||
|
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
|
||||||
|
encoder_hid_dim (`int`, *optional*, defaults to None):
|
||||||
|
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
|
||||||
|
dimension to `cross_attention_dim`.
|
||||||
|
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
|
||||||
|
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
|
||||||
|
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
|
||||||
|
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
|
||||||
|
The dimension of the attention heads.
|
||||||
|
use_linear_projection (`bool`, defaults to `False`):
|
||||||
|
class_embed_type (`str`, *optional*, defaults to `None`):
|
||||||
|
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
|
||||||
|
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
|
||||||
|
addition_embed_type (`str`, *optional*, defaults to `None`):
|
||||||
|
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
|
||||||
|
"text". "text" will use the `TextTimeEmbedding` layer.
|
||||||
|
num_class_embeds (`int`, *optional*, defaults to 0):
|
||||||
|
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
|
||||||
|
class conditioning with `class_embed_type` equal to `None`.
|
||||||
|
upcast_attention (`bool`, defaults to `False`):
|
||||||
|
resnet_time_scale_shift (`str`, defaults to `"default"`):
|
||||||
|
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
|
||||||
|
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
|
||||||
|
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
|
||||||
|
`class_embed_type="projection"`.
|
||||||
|
brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
|
||||||
|
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
|
||||||
|
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
|
||||||
|
The tuple of output channel for each block in the `conditioning_embedding` layer.
|
||||||
|
global_pool_conditions (`bool`, defaults to `False`):
|
||||||
|
TODO(Patrick) - unused parameter.
|
||||||
|
addition_embed_type_num_heads (`int`, defaults to 64):
|
||||||
|
The number of heads to use for the `TextTimeEmbedding` layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_supports_gradient_checkpointing = True
|
||||||
|
|
||||||
|
@register_to_config
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int = 4,
|
||||||
|
conditioning_channels: int = 5,
|
||||||
|
flip_sin_to_cos: bool = True,
|
||||||
|
freq_shift: int = 0,
|
||||||
|
down_block_types: Tuple[str, ...] = (
|
||||||
|
"DownBlock2D",
|
||||||
|
"DownBlock2D",
|
||||||
|
"DownBlock2D",
|
||||||
|
"DownBlock2D",
|
||||||
|
),
|
||||||
|
mid_block_type: Optional[str] = "UNetMidBlock2D",
|
||||||
|
up_block_types: Tuple[str, ...] = (
|
||||||
|
"UpBlock2D",
|
||||||
|
"UpBlock2D",
|
||||||
|
"UpBlock2D",
|
||||||
|
"UpBlock2D",
|
||||||
|
),
|
||||||
|
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||||
|
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||||
|
layers_per_block: int = 2,
|
||||||
|
downsample_padding: int = 1,
|
||||||
|
mid_block_scale_factor: float = 1,
|
||||||
|
act_fn: str = "silu",
|
||||||
|
norm_num_groups: Optional[int] = 32,
|
||||||
|
norm_eps: float = 1e-5,
|
||||||
|
cross_attention_dim: int = 1280,
|
||||||
|
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||||
|
encoder_hid_dim: Optional[int] = None,
|
||||||
|
encoder_hid_dim_type: Optional[str] = None,
|
||||||
|
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||||
|
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||||
|
use_linear_projection: bool = False,
|
||||||
|
class_embed_type: Optional[str] = None,
|
||||||
|
addition_embed_type: Optional[str] = None,
|
||||||
|
addition_time_embed_dim: Optional[int] = None,
|
||||||
|
num_class_embeds: Optional[int] = None,
|
||||||
|
upcast_attention: bool = False,
|
||||||
|
resnet_time_scale_shift: str = "default",
|
||||||
|
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||||
|
brushnet_conditioning_channel_order: str = "rgb",
|
||||||
|
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||||
|
global_pool_conditions: bool = False,
|
||||||
|
addition_embed_type_num_heads: int = 64,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
# If `num_attention_heads` is not defined (which is the case for most models)
|
||||||
|
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
|
||||||
|
# The reason for this behavior is to correct for incorrectly named variables that were introduced
|
||||||
|
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
|
||||||
|
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
|
||||||
|
# which is why we correct for the naming here.
|
||||||
|
num_attention_heads = num_attention_heads or attention_head_dim
|
||||||
|
|
||||||
|
# Check inputs
|
||||||
|
if len(down_block_types) != len(up_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(block_out_channels) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
|
||||||
|
raise ValueError(
|
||||||
|
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(transformer_layers_per_block, int):
|
||||||
|
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
|
||||||
|
|
||||||
|
# input
|
||||||
|
conv_in_kernel = 3
|
||||||
|
conv_in_padding = (conv_in_kernel - 1) // 2
|
||||||
|
self.conv_in_condition = nn.Conv2d(
|
||||||
|
in_channels + conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel,
|
||||||
|
padding=conv_in_padding
|
||||||
|
)
|
||||||
|
|
||||||
|
# time
|
||||||
|
time_embed_dim = block_out_channels[0] * 4
|
||||||
|
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
|
||||||
|
timestep_input_dim = block_out_channels[0]
|
||||||
|
self.time_embedding = TimestepEmbedding(
|
||||||
|
timestep_input_dim,
|
||||||
|
time_embed_dim,
|
||||||
|
act_fn=act_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
|
||||||
|
encoder_hid_dim_type = "text_proj"
|
||||||
|
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
|
||||||
|
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
|
||||||
|
|
||||||
|
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
|
||||||
|
)
|
||||||
|
|
||||||
|
if encoder_hid_dim_type == "text_proj":
|
||||||
|
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
|
||||||
|
elif encoder_hid_dim_type == "text_image_proj":
|
||||||
|
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||||
|
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||||
|
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
|
||||||
|
self.encoder_hid_proj = TextImageProjection(
|
||||||
|
text_embed_dim=encoder_hid_dim,
|
||||||
|
image_embed_dim=cross_attention_dim,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif encoder_hid_dim_type is not None:
|
||||||
|
raise ValueError(
|
||||||
|
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.encoder_hid_proj = None
|
||||||
|
|
||||||
|
# class embedding
|
||||||
|
if class_embed_type is None and num_class_embeds is not None:
|
||||||
|
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||||
|
elif class_embed_type == "timestep":
|
||||||
|
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||||
|
elif class_embed_type == "identity":
|
||||||
|
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||||
|
elif class_embed_type == "projection":
|
||||||
|
if projection_class_embeddings_input_dim is None:
|
||||||
|
raise ValueError(
|
||||||
|
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
|
||||||
|
)
|
||||||
|
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
|
||||||
|
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
|
||||||
|
# 2. it projects from an arbitrary input dimension.
|
||||||
|
#
|
||||||
|
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
|
||||||
|
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
|
||||||
|
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
|
||||||
|
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||||
|
else:
|
||||||
|
self.class_embedding = None
|
||||||
|
|
||||||
|
if addition_embed_type == "text":
|
||||||
|
if encoder_hid_dim is not None:
|
||||||
|
text_time_embedding_from_dim = encoder_hid_dim
|
||||||
|
else:
|
||||||
|
text_time_embedding_from_dim = cross_attention_dim
|
||||||
|
|
||||||
|
self.add_embedding = TextTimeEmbedding(
|
||||||
|
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
|
||||||
|
)
|
||||||
|
elif addition_embed_type == "text_image":
|
||||||
|
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
|
||||||
|
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
|
||||||
|
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
|
||||||
|
self.add_embedding = TextImageTimeEmbedding(
|
||||||
|
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
|
||||||
|
)
|
||||||
|
elif addition_embed_type == "text_time":
|
||||||
|
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
|
||||||
|
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
|
||||||
|
|
||||||
|
elif addition_embed_type is not None:
|
||||||
|
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
|
||||||
|
|
||||||
|
self.down_blocks = nn.ModuleList([])
|
||||||
|
self.brushnet_down_blocks = nn.ModuleList([])
|
||||||
|
|
||||||
|
if isinstance(only_cross_attention, bool):
|
||||||
|
only_cross_attention = [only_cross_attention] * len(down_block_types)
|
||||||
|
|
||||||
|
if isinstance(attention_head_dim, int):
|
||||||
|
attention_head_dim = (attention_head_dim,) * len(down_block_types)
|
||||||
|
|
||||||
|
if isinstance(num_attention_heads, int):
|
||||||
|
num_attention_heads = (num_attention_heads,) * len(down_block_types)
|
||||||
|
|
||||||
|
# down
|
||||||
|
output_channel = block_out_channels[0]
|
||||||
|
|
||||||
|
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||||
|
brushnet_block = zero_module(brushnet_block)
|
||||||
|
self.brushnet_down_blocks.append(brushnet_block)
|
||||||
|
|
||||||
|
for i, down_block_type in enumerate(down_block_types):
|
||||||
|
input_channel = output_channel
|
||||||
|
output_channel = block_out_channels[i]
|
||||||
|
is_final_block = i == len(block_out_channels) - 1
|
||||||
|
|
||||||
|
down_block = get_down_block(
|
||||||
|
down_block_type,
|
||||||
|
num_layers=layers_per_block,
|
||||||
|
transformer_layers_per_block=transformer_layers_per_block[i],
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
add_downsample=not is_final_block,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
num_attention_heads=num_attention_heads[i],
|
||||||
|
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||||
|
downsample_padding=downsample_padding,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
only_cross_attention=only_cross_attention[i],
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.down_blocks.append(down_block)
|
||||||
|
|
||||||
|
for _ in range(layers_per_block):
|
||||||
|
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||||
|
brushnet_block = zero_module(brushnet_block)
|
||||||
|
self.brushnet_down_blocks.append(brushnet_block)
|
||||||
|
|
||||||
|
if not is_final_block:
|
||||||
|
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||||
|
brushnet_block = zero_module(brushnet_block)
|
||||||
|
self.brushnet_down_blocks.append(brushnet_block)
|
||||||
|
|
||||||
|
# mid
|
||||||
|
mid_block_channel = block_out_channels[-1]
|
||||||
|
|
||||||
|
brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
|
||||||
|
brushnet_block = zero_module(brushnet_block)
|
||||||
|
self.brushnet_mid_block = brushnet_block
|
||||||
|
|
||||||
|
self.mid_block = MidBlock2D(
|
||||||
|
in_channels=mid_block_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
dropout=0.0,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
output_scale_factor=mid_block_scale_factor,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
)
|
||||||
|
|
||||||
|
# count how many layers upsample the images
|
||||||
|
self.num_upsamplers = 0
|
||||||
|
|
||||||
|
# up
|
||||||
|
reversed_block_out_channels = list(reversed(block_out_channels))
|
||||||
|
reversed_num_attention_heads = list(reversed(num_attention_heads))
|
||||||
|
reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
|
||||||
|
only_cross_attention = list(reversed(only_cross_attention))
|
||||||
|
|
||||||
|
output_channel = reversed_block_out_channels[0]
|
||||||
|
|
||||||
|
self.up_blocks = nn.ModuleList([])
|
||||||
|
self.brushnet_up_blocks = nn.ModuleList([])
|
||||||
|
|
||||||
|
for i, up_block_type in enumerate(up_block_types):
|
||||||
|
is_final_block = i == len(block_out_channels) - 1
|
||||||
|
|
||||||
|
prev_output_channel = output_channel
|
||||||
|
output_channel = reversed_block_out_channels[i]
|
||||||
|
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
|
||||||
|
|
||||||
|
# add upsample block for all BUT final layer
|
||||||
|
if not is_final_block:
|
||||||
|
add_upsample = True
|
||||||
|
self.num_upsamplers += 1
|
||||||
|
else:
|
||||||
|
add_upsample = False
|
||||||
|
|
||||||
|
up_block = get_up_block(
|
||||||
|
up_block_type,
|
||||||
|
num_layers=layers_per_block + 1,
|
||||||
|
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
|
||||||
|
in_channels=input_channel,
|
||||||
|
out_channels=output_channel,
|
||||||
|
prev_output_channel=prev_output_channel,
|
||||||
|
temb_channels=time_embed_dim,
|
||||||
|
add_upsample=add_upsample,
|
||||||
|
resnet_eps=norm_eps,
|
||||||
|
resnet_act_fn=act_fn,
|
||||||
|
resolution_idx=i,
|
||||||
|
resnet_groups=norm_num_groups,
|
||||||
|
cross_attention_dim=cross_attention_dim,
|
||||||
|
num_attention_heads=reversed_num_attention_heads[i],
|
||||||
|
use_linear_projection=use_linear_projection,
|
||||||
|
only_cross_attention=only_cross_attention[i],
|
||||||
|
upcast_attention=upcast_attention,
|
||||||
|
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||||
|
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.up_blocks.append(up_block)
|
||||||
|
prev_output_channel = output_channel
|
||||||
|
|
||||||
|
for _ in range(layers_per_block + 1):
|
||||||
|
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||||
|
brushnet_block = zero_module(brushnet_block)
|
||||||
|
self.brushnet_up_blocks.append(brushnet_block)
|
||||||
|
|
||||||
|
if not is_final_block:
|
||||||
|
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
|
||||||
|
brushnet_block = zero_module(brushnet_block)
|
||||||
|
self.brushnet_up_blocks.append(brushnet_block)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_unet(
|
||||||
|
cls,
|
||||||
|
unet: UNet2DConditionModel,
|
||||||
|
brushnet_conditioning_channel_order: str = "rgb",
|
||||||
|
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||||
|
load_weights_from_unet: bool = True,
|
||||||
|
conditioning_channels: int = 5,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
unet (`UNet2DConditionModel`):
|
||||||
|
The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
|
||||||
|
where applicable.
|
||||||
|
"""
|
||||||
|
transformer_layers_per_block = (
|
||||||
|
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
|
||||||
|
)
|
||||||
|
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
|
||||||
|
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
|
||||||
|
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
|
||||||
|
addition_time_embed_dim = (
|
||||||
|
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
|
||||||
|
)
|
||||||
|
|
||||||
|
brushnet = cls(
|
||||||
|
in_channels=unet.config.in_channels,
|
||||||
|
conditioning_channels=conditioning_channels,
|
||||||
|
flip_sin_to_cos=unet.config.flip_sin_to_cos,
|
||||||
|
freq_shift=unet.config.freq_shift,
|
||||||
|
down_block_types=['DownBlock2D', 'DownBlock2D', 'DownBlock2D', 'DownBlock2D'],
|
||||||
|
mid_block_type='MidBlock2D',
|
||||||
|
up_block_types=['UpBlock2D', 'UpBlock2D', 'UpBlock2D', 'UpBlock2D'],
|
||||||
|
only_cross_attention=unet.config.only_cross_attention,
|
||||||
|
block_out_channels=unet.config.block_out_channels,
|
||||||
|
layers_per_block=unet.config.layers_per_block,
|
||||||
|
downsample_padding=unet.config.downsample_padding,
|
||||||
|
mid_block_scale_factor=unet.config.mid_block_scale_factor,
|
||||||
|
act_fn=unet.config.act_fn,
|
||||||
|
norm_num_groups=unet.config.norm_num_groups,
|
||||||
|
norm_eps=unet.config.norm_eps,
|
||||||
|
cross_attention_dim=unet.config.cross_attention_dim,
|
||||||
|
transformer_layers_per_block=transformer_layers_per_block,
|
||||||
|
encoder_hid_dim=encoder_hid_dim,
|
||||||
|
encoder_hid_dim_type=encoder_hid_dim_type,
|
||||||
|
attention_head_dim=unet.config.attention_head_dim,
|
||||||
|
num_attention_heads=unet.config.num_attention_heads,
|
||||||
|
use_linear_projection=unet.config.use_linear_projection,
|
||||||
|
class_embed_type=unet.config.class_embed_type,
|
||||||
|
addition_embed_type=addition_embed_type,
|
||||||
|
addition_time_embed_dim=addition_time_embed_dim,
|
||||||
|
num_class_embeds=unet.config.num_class_embeds,
|
||||||
|
upcast_attention=unet.config.upcast_attention,
|
||||||
|
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
|
||||||
|
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
|
||||||
|
brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
|
||||||
|
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||||
|
)
|
||||||
|
|
||||||
|
if load_weights_from_unet:
|
||||||
|
conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
|
||||||
|
conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
|
||||||
|
conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
|
||||||
|
brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
|
||||||
|
brushnet.conv_in_condition.bias = unet.conv_in.bias
|
||||||
|
|
||||||
|
brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
|
||||||
|
brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
|
||||||
|
|
||||||
|
if brushnet.class_embedding:
|
||||||
|
brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
|
||||||
|
|
||||||
|
brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
|
||||||
|
brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
|
||||||
|
brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
|
||||||
|
|
||||||
|
return brushnet
|
||||||
|
|
||||||
|
@property
|
||||||
|
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||||
|
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||||
|
r"""
|
||||||
|
Returns:
|
||||||
|
`dict` of attention processors: A dictionary containing all attention processors used in the model with
|
||||||
|
indexed by its weight name.
|
||||||
|
"""
|
||||||
|
# set recursively
|
||||||
|
processors = {}
|
||||||
|
|
||||||
|
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||||
|
if hasattr(module, "get_processor"):
|
||||||
|
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||||
|
|
||||||
|
for sub_name, child in module.named_children():
|
||||||
|
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||||
|
|
||||||
|
return processors
|
||||||
|
|
||||||
|
for name, module in self.named_children():
|
||||||
|
fn_recursive_add_processors(name, module, processors)
|
||||||
|
|
||||||
|
return processors
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||||
|
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
||||||
|
r"""
|
||||||
|
Sets the attention processor to use to compute attention.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||||
|
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||||
|
for **all** `Attention` layers.
|
||||||
|
|
||||||
|
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||||
|
processor. This is strongly recommended when setting trainable attention processors.
|
||||||
|
|
||||||
|
"""
|
||||||
|
count = len(self.attn_processors.keys())
|
||||||
|
|
||||||
|
if isinstance(processor, dict) and len(processor) != count:
|
||||||
|
raise ValueError(
|
||||||
|
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||||
|
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||||
|
)
|
||||||
|
|
||||||
|
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||||
|
if hasattr(module, "set_processor"):
|
||||||
|
if not isinstance(processor, dict):
|
||||||
|
module.set_processor(processor)
|
||||||
|
else:
|
||||||
|
module.set_processor(processor.pop(f"{name}.processor"))
|
||||||
|
|
||||||
|
for sub_name, child in module.named_children():
|
||||||
|
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||||
|
|
||||||
|
for name, module in self.named_children():
|
||||||
|
fn_recursive_attn_processor(name, module, processor)
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||||
|
def set_default_attn_processor(self):
|
||||||
|
"""
|
||||||
|
Disables custom attention processors and sets the default attention implementation.
|
||||||
|
"""
|
||||||
|
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||||
|
processor = AttnAddedKVProcessor()
|
||||||
|
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||||
|
processor = AttnProcessor()
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.set_attn_processor(processor)
|
||||||
|
|
||||||
|
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
|
||||||
|
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
|
||||||
|
r"""
|
||||||
|
Enable sliced attention computation.
|
||||||
|
|
||||||
|
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
|
||||||
|
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
|
||||||
|
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
|
||||||
|
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
|
||||||
|
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
|
||||||
|
must be a multiple of `slice_size`.
|
||||||
|
"""
|
||||||
|
sliceable_head_dims = []
|
||||||
|
|
||||||
|
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
|
||||||
|
if hasattr(module, "set_attention_slice"):
|
||||||
|
sliceable_head_dims.append(module.sliceable_head_dim)
|
||||||
|
|
||||||
|
for child in module.children():
|
||||||
|
fn_recursive_retrieve_sliceable_dims(child)
|
||||||
|
|
||||||
|
# retrieve number of attention layers
|
||||||
|
for module in self.children():
|
||||||
|
fn_recursive_retrieve_sliceable_dims(module)
|
||||||
|
|
||||||
|
num_sliceable_layers = len(sliceable_head_dims)
|
||||||
|
|
||||||
|
if slice_size == "auto":
|
||||||
|
# half the attention head size is usually a good trade-off between
|
||||||
|
# speed and memory
|
||||||
|
slice_size = [dim // 2 for dim in sliceable_head_dims]
|
||||||
|
elif slice_size == "max":
|
||||||
|
# make smallest slice possible
|
||||||
|
slice_size = num_sliceable_layers * [1]
|
||||||
|
|
||||||
|
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
|
||||||
|
|
||||||
|
if len(slice_size) != len(sliceable_head_dims):
|
||||||
|
raise ValueError(
|
||||||
|
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
|
||||||
|
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(len(slice_size)):
|
||||||
|
size = slice_size[i]
|
||||||
|
dim = sliceable_head_dims[i]
|
||||||
|
if size is not None and size > dim:
|
||||||
|
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
|
||||||
|
|
||||||
|
# Recursively walk through all the children.
|
||||||
|
# Any children which exposes the set_attention_slice method
|
||||||
|
# gets the message
|
||||||
|
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
|
||||||
|
if hasattr(module, "set_attention_slice"):
|
||||||
|
module.set_attention_slice(slice_size.pop())
|
||||||
|
|
||||||
|
for child in module.children():
|
||||||
|
fn_recursive_set_attention_slice(child, slice_size)
|
||||||
|
|
||||||
|
reversed_slice_size = list(reversed(slice_size))
|
||||||
|
for module in self.children():
|
||||||
|
fn_recursive_set_attention_slice(module, reversed_slice_size)
|
||||||
|
|
||||||
|
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
|
||||||
|
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
|
||||||
|
module.gradient_checkpointing = value
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
sample: torch.FloatTensor,
|
||||||
|
timestep: Union[torch.Tensor, float, int],
|
||||||
|
encoder_hidden_states: torch.Tensor,
|
||||||
|
brushnet_cond: torch.FloatTensor,
|
||||||
|
conditioning_scale: float = 1.0,
|
||||||
|
class_labels: Optional[torch.Tensor] = None,
|
||||||
|
timestep_cond: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
guess_mode: bool = False,
|
||||||
|
return_dict: bool = True,
|
||||||
|
) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
|
||||||
|
"""
|
||||||
|
The [`BrushNetModel`] forward method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample (`torch.FloatTensor`):
|
||||||
|
The noisy input tensor.
|
||||||
|
timestep (`Union[torch.Tensor, float, int]`):
|
||||||
|
The number of timesteps to denoise an input.
|
||||||
|
encoder_hidden_states (`torch.Tensor`):
|
||||||
|
The encoder hidden states.
|
||||||
|
brushnet_cond (`torch.FloatTensor`):
|
||||||
|
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
|
||||||
|
conditioning_scale (`float`, defaults to `1.0`):
|
||||||
|
The scale factor for BrushNet outputs.
|
||||||
|
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||||
|
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
|
||||||
|
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
|
||||||
|
embeddings.
|
||||||
|
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||||
|
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||||
|
negative values to the attention scores corresponding to "discard" tokens.
|
||||||
|
added_cond_kwargs (`dict`):
|
||||||
|
Additional conditions for the Stable Diffusion XL UNet.
|
||||||
|
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
|
||||||
|
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
|
||||||
|
guess_mode (`bool`, defaults to `False`):
|
||||||
|
In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if
|
||||||
|
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||||
|
return_dict (`bool`, defaults to `True`):
|
||||||
|
Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~models.brushnet.BrushNetOutput`] **or** `tuple`:
|
||||||
|
If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
|
||||||
|
returned where the first element is the sample tensor.
|
||||||
|
"""
|
||||||
|
# check channel order
|
||||||
|
channel_order = self.config.brushnet_conditioning_channel_order
|
||||||
|
|
||||||
|
if channel_order == "rgb":
|
||||||
|
# in rgb order by default
|
||||||
|
...
|
||||||
|
elif channel_order == "bgr":
|
||||||
|
brushnet_cond = torch.flip(brushnet_cond, dims=[1])
|
||||||
|
else:
|
||||||
|
raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
|
||||||
|
|
||||||
|
# prepare attention_mask
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||||
|
attention_mask = attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
# 1. time
|
||||||
|
timesteps = timestep
|
||||||
|
if not torch.is_tensor(timesteps):
|
||||||
|
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
||||||
|
# This would be a good case for the `match` statement (Python 3.10+)
|
||||||
|
is_mps = sample.device.type == "mps"
|
||||||
|
if isinstance(timestep, float):
|
||||||
|
dtype = torch.float32 if is_mps else torch.float64
|
||||||
|
else:
|
||||||
|
dtype = torch.int32 if is_mps else torch.int64
|
||||||
|
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
||||||
|
elif len(timesteps.shape) == 0:
|
||||||
|
timesteps = timesteps[None].to(sample.device)
|
||||||
|
|
||||||
|
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
|
timesteps = timesteps.expand(sample.shape[0])
|
||||||
|
|
||||||
|
t_emb = self.time_proj(timesteps)
|
||||||
|
|
||||||
|
# timesteps does not contain any weights and will always return f32 tensors
|
||||||
|
# but time_embedding might actually be running in fp16. so we need to cast here.
|
||||||
|
# there might be better ways to encapsulate this.
|
||||||
|
t_emb = t_emb.to(dtype=sample.dtype)
|
||||||
|
|
||||||
|
emb = self.time_embedding(t_emb, timestep_cond)
|
||||||
|
aug_emb = None
|
||||||
|
|
||||||
|
if self.class_embedding is not None:
|
||||||
|
if class_labels is None:
|
||||||
|
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||||
|
|
||||||
|
if self.config.class_embed_type == "timestep":
|
||||||
|
class_labels = self.time_proj(class_labels)
|
||||||
|
|
||||||
|
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||||
|
emb = emb + class_emb
|
||||||
|
|
||||||
|
if self.config.addition_embed_type is not None:
|
||||||
|
if self.config.addition_embed_type == "text":
|
||||||
|
aug_emb = self.add_embedding(encoder_hidden_states)
|
||||||
|
|
||||||
|
elif self.config.addition_embed_type == "text_time":
|
||||||
|
if "text_embeds" not in added_cond_kwargs:
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
||||||
|
)
|
||||||
|
text_embeds = added_cond_kwargs.get("text_embeds")
|
||||||
|
if "time_ids" not in added_cond_kwargs:
|
||||||
|
raise ValueError(
|
||||||
|
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
||||||
|
)
|
||||||
|
time_ids = added_cond_kwargs.get("time_ids")
|
||||||
|
time_embeds = self.add_time_proj(time_ids.flatten())
|
||||||
|
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
||||||
|
|
||||||
|
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
||||||
|
add_embeds = add_embeds.to(emb.dtype)
|
||||||
|
aug_emb = self.add_embedding(add_embeds)
|
||||||
|
|
||||||
|
emb = emb + aug_emb if aug_emb is not None else emb
|
||||||
|
|
||||||
|
# 2. pre-process
|
||||||
|
brushnet_cond = torch.concat([sample, brushnet_cond], 1)
|
||||||
|
sample = self.conv_in_condition(brushnet_cond)
|
||||||
|
|
||||||
|
# 3. down
|
||||||
|
down_block_res_samples = (sample,)
|
||||||
|
for downsample_block in self.down_blocks:
|
||||||
|
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||||
|
sample, res_samples = downsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||||
|
|
||||||
|
down_block_res_samples += res_samples
|
||||||
|
|
||||||
|
# 4. PaintingNet down blocks
|
||||||
|
brushnet_down_block_res_samples = ()
|
||||||
|
for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
|
||||||
|
down_block_res_sample = brushnet_down_block(down_block_res_sample)
|
||||||
|
brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
|
||||||
|
|
||||||
|
# 5. mid
|
||||||
|
if self.mid_block is not None:
|
||||||
|
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||||
|
sample = self.mid_block(
|
||||||
|
sample,
|
||||||
|
emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample = self.mid_block(sample, emb)
|
||||||
|
|
||||||
|
# 6. BrushNet mid blocks
|
||||||
|
brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
|
||||||
|
|
||||||
|
# 7. up
|
||||||
|
up_block_res_samples = ()
|
||||||
|
for i, upsample_block in enumerate(self.up_blocks):
|
||||||
|
is_final_block = i == len(self.up_blocks) - 1
|
||||||
|
|
||||||
|
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
||||||
|
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||||
|
|
||||||
|
# if we have not reached the final block and need to forward the
|
||||||
|
# upsample size, we do it here
|
||||||
|
if not is_final_block:
|
||||||
|
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||||
|
|
||||||
|
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
||||||
|
sample, up_res_samples = upsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
res_hidden_states_tuple=res_samples,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
upsample_size=upsample_size,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
return_res_samples=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample, up_res_samples = upsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
res_hidden_states_tuple=res_samples,
|
||||||
|
upsample_size=upsample_size,
|
||||||
|
return_res_samples=True
|
||||||
|
)
|
||||||
|
|
||||||
|
up_block_res_samples += up_res_samples
|
||||||
|
|
||||||
|
# 8. BrushNet up blocks
|
||||||
|
brushnet_up_block_res_samples = ()
|
||||||
|
for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
|
||||||
|
up_block_res_sample = brushnet_up_block(up_block_res_sample)
|
||||||
|
brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
|
||||||
|
|
||||||
|
# 6. scaling
|
||||||
|
if guess_mode and not self.config.global_pool_conditions:
|
||||||
|
scales = torch.logspace(-1, 0,
|
||||||
|
len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
|
||||||
|
device=sample.device) # 0.1 to 1.0
|
||||||
|
scales = scales * conditioning_scale
|
||||||
|
|
||||||
|
brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples,
|
||||||
|
scales[:len(
|
||||||
|
brushnet_down_block_res_samples)])]
|
||||||
|
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
|
||||||
|
brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples,
|
||||||
|
scales[
|
||||||
|
len(brushnet_down_block_res_samples) + 1:])]
|
||||||
|
else:
|
||||||
|
brushnet_down_block_res_samples = [sample * conditioning_scale for sample in
|
||||||
|
brushnet_down_block_res_samples]
|
||||||
|
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
|
||||||
|
brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
|
||||||
|
|
||||||
|
if self.config.global_pool_conditions:
|
||||||
|
brushnet_down_block_res_samples = [
|
||||||
|
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
|
||||||
|
]
|
||||||
|
brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
|
||||||
|
brushnet_up_block_res_samples = [
|
||||||
|
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
|
||||||
|
]
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
|
||||||
|
|
||||||
|
return BrushNetOutput(
|
||||||
|
down_block_res_samples=brushnet_down_block_res_samples,
|
||||||
|
mid_block_res_sample=brushnet_mid_block_res_sample,
|
||||||
|
up_block_res_samples=brushnet_up_block_res_samples
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def zero_module(module):
|
||||||
|
for p in module.parameters():
|
||||||
|
nn.init.zeros_(p)
|
||||||
|
return module
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
BrushNetModel.from_pretrained("/Users/cwq/data/models/brushnet/brushnet_random_mask", variant='fp16',
|
||||||
|
use_safetensors=True)
|
322
iopaint/model/brushnet/brushnet_unet_forward.py
Normal file
322
iopaint/model/brushnet/brushnet_unet_forward.py
Normal file
@ -0,0 +1,322 @@
|
|||||||
|
from typing import Union, Optional, Dict, Any, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
||||||
|
from diffusers.utils import USE_PEFT_BACKEND, unscale_lora_layers, deprecate, scale_lora_layers
|
||||||
|
|
||||||
|
|
||||||
|
def brushnet_unet_forward(
|
||||||
|
self,
|
||||||
|
sample: torch.FloatTensor,
|
||||||
|
timestep: Union[torch.Tensor, float, int],
|
||||||
|
encoder_hidden_states: torch.Tensor,
|
||||||
|
class_labels: Optional[torch.Tensor] = None,
|
||||||
|
timestep_cond: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
|
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
||||||
|
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
return_dict: bool = True,
|
||||||
|
down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||||
|
r"""
|
||||||
|
The [`UNet2DConditionModel`] forward method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample (`torch.FloatTensor`):
|
||||||
|
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
||||||
|
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||||
|
encoder_hidden_states (`torch.FloatTensor`):
|
||||||
|
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
||||||
|
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||||
|
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
|
||||||
|
through the `self.time_embedding` layer to obtain the timestep embeddings.
|
||||||
|
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||||
|
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||||
|
negative values to the attention scores corresponding to "discard" tokens.
|
||||||
|
cross_attention_kwargs (`dict`, *optional*):
|
||||||
|
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||||
|
`self.processor` in
|
||||||
|
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||||
|
added_cond_kwargs: (`dict`, *optional*):
|
||||||
|
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
||||||
|
are passed along to the UNet blocks.
|
||||||
|
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
|
||||||
|
A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
||||||
|
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
||||||
|
A tensor that if specified is added to the residual of the middle unet block.
|
||||||
|
encoder_attention_mask (`torch.Tensor`):
|
||||||
|
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
||||||
|
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
||||||
|
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||||
|
tuple.
|
||||||
|
cross_attention_kwargs (`dict`, *optional*):
|
||||||
|
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
||||||
|
added_cond_kwargs: (`dict`, *optional*):
|
||||||
|
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
||||||
|
are passed along to the UNet blocks.
|
||||||
|
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
||||||
|
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
|
||||||
|
example from ControlNet side model(s)
|
||||||
|
mid_block_additional_residual (`torch.Tensor`, *optional*):
|
||||||
|
additional residual to be added to UNet mid block output, for example from ControlNet side model
|
||||||
|
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
||||||
|
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||||
|
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
||||||
|
a `tuple` is returned where the first element is the sample tensor.
|
||||||
|
"""
|
||||||
|
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||||
|
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
||||||
|
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||||
|
# on the fly if necessary.
|
||||||
|
default_overall_up_factor = 2 ** self.num_upsamplers
|
||||||
|
|
||||||
|
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||||
|
forward_upsample_size = False
|
||||||
|
upsample_size = None
|
||||||
|
|
||||||
|
for dim in sample.shape[-2:]:
|
||||||
|
if dim % default_overall_up_factor != 0:
|
||||||
|
# Forward upsample size to force interpolation output size.
|
||||||
|
forward_upsample_size = True
|
||||||
|
break
|
||||||
|
|
||||||
|
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
||||||
|
# expects mask of shape:
|
||||||
|
# [batch, key_tokens]
|
||||||
|
# adds singleton query_tokens dimension:
|
||||||
|
# [batch, 1, key_tokens]
|
||||||
|
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||||
|
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||||
|
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||||
|
if attention_mask is not None:
|
||||||
|
# assume that mask is expressed as:
|
||||||
|
# (1 = keep, 0 = discard)
|
||||||
|
# convert mask into a bias that can be added to attention scores:
|
||||||
|
# (keep = +0, discard = -10000.0)
|
||||||
|
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||||
|
attention_mask = attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||||
|
if encoder_attention_mask is not None:
|
||||||
|
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
||||||
|
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
# 0. center input if necessary
|
||||||
|
if self.config.center_input_sample:
|
||||||
|
sample = 2 * sample - 1.0
|
||||||
|
|
||||||
|
# 1. time
|
||||||
|
t_emb = self.get_time_embed(sample=sample, timestep=timestep)
|
||||||
|
emb = self.time_embedding(t_emb, timestep_cond)
|
||||||
|
aug_emb = None
|
||||||
|
|
||||||
|
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
|
||||||
|
if class_emb is not None:
|
||||||
|
if self.config.class_embeddings_concat:
|
||||||
|
emb = torch.cat([emb, class_emb], dim=-1)
|
||||||
|
else:
|
||||||
|
emb = emb + class_emb
|
||||||
|
|
||||||
|
aug_emb = self.get_aug_embed(
|
||||||
|
emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
||||||
|
)
|
||||||
|
if self.config.addition_embed_type == "image_hint":
|
||||||
|
aug_emb, hint = aug_emb
|
||||||
|
sample = torch.cat([sample, hint], dim=1)
|
||||||
|
|
||||||
|
emb = emb + aug_emb if aug_emb is not None else emb
|
||||||
|
|
||||||
|
if self.time_embed_act is not None:
|
||||||
|
emb = self.time_embed_act(emb)
|
||||||
|
|
||||||
|
encoder_hidden_states = self.process_encoder_hidden_states(
|
||||||
|
encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. pre-process
|
||||||
|
sample = self.conv_in(sample)
|
||||||
|
|
||||||
|
# 2.5 GLIGEN position net
|
||||||
|
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
|
||||||
|
cross_attention_kwargs = cross_attention_kwargs.copy()
|
||||||
|
gligen_args = cross_attention_kwargs.pop("gligen")
|
||||||
|
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
|
||||||
|
|
||||||
|
# 3. down
|
||||||
|
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||||
|
if USE_PEFT_BACKEND:
|
||||||
|
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||||
|
scale_lora_layers(self, lora_scale)
|
||||||
|
|
||||||
|
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
||||||
|
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
|
||||||
|
is_adapter = down_intrablock_additional_residuals is not None
|
||||||
|
# maintain backward compatibility for legacy usage, where
|
||||||
|
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
|
||||||
|
# but can only use one or the other
|
||||||
|
is_brushnet = down_block_add_samples is not None and mid_block_add_sample is not None and up_block_add_samples is not None
|
||||||
|
if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
|
||||||
|
deprecate(
|
||||||
|
"T2I should not use down_block_additional_residuals",
|
||||||
|
"1.3.0",
|
||||||
|
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
|
||||||
|
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
|
||||||
|
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
|
||||||
|
standard_warn=False,
|
||||||
|
)
|
||||||
|
down_intrablock_additional_residuals = down_block_additional_residuals
|
||||||
|
is_adapter = True
|
||||||
|
|
||||||
|
down_block_res_samples = (sample,)
|
||||||
|
|
||||||
|
if is_brushnet:
|
||||||
|
sample = sample + down_block_add_samples.pop(0)
|
||||||
|
|
||||||
|
for downsample_block in self.down_blocks:
|
||||||
|
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
||||||
|
# For t2i-adapter CrossAttnDownBlock2D
|
||||||
|
additional_residuals = {}
|
||||||
|
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
||||||
|
additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
|
||||||
|
|
||||||
|
if is_brushnet and len(down_block_add_samples) > 0:
|
||||||
|
additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
|
||||||
|
for _ in range(
|
||||||
|
len(downsample_block.resnets) + (downsample_block.downsamplers != None))]
|
||||||
|
|
||||||
|
sample, res_samples = downsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
**additional_residuals,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
additional_residuals = {}
|
||||||
|
if is_brushnet and len(down_block_add_samples) > 0:
|
||||||
|
additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
|
||||||
|
for _ in range(
|
||||||
|
len(downsample_block.resnets) + (downsample_block.downsamplers != None))]
|
||||||
|
|
||||||
|
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale,
|
||||||
|
**additional_residuals)
|
||||||
|
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
||||||
|
sample += down_intrablock_additional_residuals.pop(0)
|
||||||
|
|
||||||
|
down_block_res_samples += res_samples
|
||||||
|
|
||||||
|
if is_controlnet:
|
||||||
|
new_down_block_res_samples = ()
|
||||||
|
|
||||||
|
for down_block_res_sample, down_block_additional_residual in zip(
|
||||||
|
down_block_res_samples, down_block_additional_residuals
|
||||||
|
):
|
||||||
|
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
||||||
|
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
||||||
|
|
||||||
|
down_block_res_samples = new_down_block_res_samples
|
||||||
|
|
||||||
|
# 4. mid
|
||||||
|
if self.mid_block is not None:
|
||||||
|
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
|
||||||
|
sample = self.mid_block(
|
||||||
|
sample,
|
||||||
|
emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample = self.mid_block(sample, emb)
|
||||||
|
|
||||||
|
# To support T2I-Adapter-XL
|
||||||
|
if (
|
||||||
|
is_adapter
|
||||||
|
and len(down_intrablock_additional_residuals) > 0
|
||||||
|
and sample.shape == down_intrablock_additional_residuals[0].shape
|
||||||
|
):
|
||||||
|
sample += down_intrablock_additional_residuals.pop(0)
|
||||||
|
|
||||||
|
if is_controlnet:
|
||||||
|
sample = sample + mid_block_additional_residual
|
||||||
|
|
||||||
|
if is_brushnet:
|
||||||
|
sample = sample + mid_block_add_sample
|
||||||
|
|
||||||
|
# 5. up
|
||||||
|
for i, upsample_block in enumerate(self.up_blocks):
|
||||||
|
is_final_block = i == len(self.up_blocks) - 1
|
||||||
|
|
||||||
|
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
|
||||||
|
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||||
|
|
||||||
|
# if we have not reached the final block and need to forward the
|
||||||
|
# upsample size, we do it here
|
||||||
|
if not is_final_block and forward_upsample_size:
|
||||||
|
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||||
|
|
||||||
|
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
||||||
|
additional_residuals = {}
|
||||||
|
if is_brushnet and len(up_block_add_samples) > 0:
|
||||||
|
additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
|
||||||
|
for _ in range(
|
||||||
|
len(upsample_block.resnets) + (upsample_block.upsamplers != None))]
|
||||||
|
|
||||||
|
sample = upsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
res_hidden_states_tuple=res_samples,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
upsample_size=upsample_size,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
**additional_residuals,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
additional_residuals = {}
|
||||||
|
if is_brushnet and len(up_block_add_samples) > 0:
|
||||||
|
additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
|
||||||
|
for _ in range(
|
||||||
|
len(upsample_block.resnets) + (upsample_block.upsamplers != None))]
|
||||||
|
|
||||||
|
sample = upsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
res_hidden_states_tuple=res_samples,
|
||||||
|
upsample_size=upsample_size,
|
||||||
|
scale=lora_scale,
|
||||||
|
**additional_residuals,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. post-process
|
||||||
|
if self.conv_norm_out:
|
||||||
|
sample = self.conv_norm_out(sample)
|
||||||
|
sample = self.conv_act(sample)
|
||||||
|
sample = self.conv_out(sample)
|
||||||
|
|
||||||
|
if USE_PEFT_BACKEND:
|
||||||
|
# remove `lora_scale` from each PEFT layer
|
||||||
|
unscale_lora_layers(self, lora_scale)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (sample,)
|
||||||
|
|
||||||
|
return UNet2DConditionOutput(sample=sample)
|
157
iopaint/model/brushnet/brushnet_wrapper.py
Normal file
157
iopaint/model/brushnet/brushnet_wrapper.py
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
import PIL.Image
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
from loguru import logger
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ..base import DiffusionInpaintModel
|
||||||
|
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||||
|
from ..original_sd_configs import get_config_files
|
||||||
|
from ..utils import (
|
||||||
|
handle_from_pretrained_exceptions,
|
||||||
|
get_torch_dtype,
|
||||||
|
enable_low_mem,
|
||||||
|
is_local_files_only,
|
||||||
|
)
|
||||||
|
from .brushnet import BrushNetModel
|
||||||
|
from .brushnet_unet_forward import brushnet_unet_forward
|
||||||
|
from .unet_2d_blocks import CrossAttnDownBlock2D_forward, DownBlock2D_forward, CrossAttnUpBlock2D_forward, \
|
||||||
|
UpBlock2D_forward
|
||||||
|
from ...schema import InpaintRequest, ModelType
|
||||||
|
|
||||||
|
|
||||||
|
class BrushNetWrapper(DiffusionInpaintModel):
|
||||||
|
pad_mod = 8
|
||||||
|
min_size = 512
|
||||||
|
|
||||||
|
def init_model(self, device: torch.device, **kwargs):
|
||||||
|
from .pipeline_brushnet import StableDiffusionBrushNetPipeline
|
||||||
|
self.model_info = kwargs["model_info"]
|
||||||
|
self.brushnet_method = kwargs["brushnet_method"]
|
||||||
|
|
||||||
|
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||||
|
self.torch_dtype = torch_dtype
|
||||||
|
|
||||||
|
model_kwargs = {
|
||||||
|
**kwargs.get("pipe_components", {}),
|
||||||
|
"local_files_only": is_local_files_only(**kwargs),
|
||||||
|
}
|
||||||
|
self.local_files_only = model_kwargs["local_files_only"]
|
||||||
|
|
||||||
|
disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
|
||||||
|
"cpu_offload", False
|
||||||
|
)
|
||||||
|
if disable_nsfw_checker:
|
||||||
|
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||||
|
model_kwargs.update(
|
||||||
|
dict(
|
||||||
|
safety_checker=None,
|
||||||
|
feature_extractor=None,
|
||||||
|
requires_safety_checker=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Loading BrushNet model from {self.brushnet_method}")
|
||||||
|
brushnet = BrushNetModel.from_pretrained(self.brushnet_method, torch_dtype=torch_dtype)
|
||||||
|
|
||||||
|
if self.model_info.is_single_file_diffusers:
|
||||||
|
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
||||||
|
model_kwargs["num_in_channels"] = 4
|
||||||
|
else:
|
||||||
|
model_kwargs["num_in_channels"] = 9
|
||||||
|
|
||||||
|
self.model = StableDiffusionBrushNetPipeline.from_single_file(
|
||||||
|
self.model_id_or_path,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
load_safety_checker=not disable_nsfw_checker,
|
||||||
|
original_config_file=get_config_files()['v1'],
|
||||||
|
brushnet=brushnet,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.model = handle_from_pretrained_exceptions(
|
||||||
|
StableDiffusionBrushNetPipeline.from_pretrained,
|
||||||
|
pretrained_model_name_or_path=self.model_id_or_path,
|
||||||
|
variant="fp16",
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
brushnet=brushnet,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
enable_low_mem(self.model, kwargs.get("low_mem", False))
|
||||||
|
|
||||||
|
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||||
|
logger.info("Enable sequential cpu offload")
|
||||||
|
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||||
|
else:
|
||||||
|
self.model = self.model.to(device)
|
||||||
|
if kwargs["sd_cpu_textencoder"]:
|
||||||
|
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||||
|
self.model.text_encoder = CPUTextEncoderWrapper(
|
||||||
|
self.model.text_encoder, torch_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
self.callback = kwargs.pop("callback", None)
|
||||||
|
|
||||||
|
# Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
|
||||||
|
self.model.unet.forward = brushnet_unet_forward.__get__(self.model.unet, self.model.unet.__class__)
|
||||||
|
|
||||||
|
for down_block in self.model.brushnet.down_blocks:
|
||||||
|
down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__)
|
||||||
|
for up_block in self.model.brushnet.up_blocks:
|
||||||
|
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)
|
||||||
|
|
||||||
|
# Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
|
||||||
|
for down_block in self.model.unet.down_blocks:
|
||||||
|
if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
|
||||||
|
down_block.forward = CrossAttnDownBlock2D_forward.__get__(down_block, down_block.__class__)
|
||||||
|
else:
|
||||||
|
down_block.forward = DownBlock2D_forward.__get__(down_block, down_block.__class__)
|
||||||
|
|
||||||
|
for up_block in self.model.unet.up_blocks:
|
||||||
|
if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
|
||||||
|
up_block.forward = CrossAttnUpBlock2D_forward.__get__(up_block, up_block.__class__)
|
||||||
|
else:
|
||||||
|
up_block.forward = UpBlock2D_forward.__get__(up_block, up_block.__class__)
|
||||||
|
|
||||||
|
def switch_brushnet_method(self, new_method: str):
|
||||||
|
self.brushnet_method = new_method
|
||||||
|
brushnet = BrushNetModel.from_pretrained(
|
||||||
|
new_method,
|
||||||
|
resume_download=True,
|
||||||
|
local_files_only=self.local_files_only,
|
||||||
|
torch_dtype=self.torch_dtype,
|
||||||
|
).to(self.model.device)
|
||||||
|
self.model.brushnet = brushnet
|
||||||
|
|
||||||
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
|
"""Input image and output image have same size
|
||||||
|
image: [H, W, C] RGB
|
||||||
|
mask: [H, W, 1] 255 means area to repaint
|
||||||
|
return: BGR IMAGE
|
||||||
|
"""
|
||||||
|
self.set_scheduler(config)
|
||||||
|
|
||||||
|
img_h, img_w = image.shape[:2]
|
||||||
|
normalized_mask = mask[:, :].astype("float32") / 255.0
|
||||||
|
image = image * (1 - normalized_mask)
|
||||||
|
image = image.astype(np.uint8)
|
||||||
|
output = self.model(
|
||||||
|
image=PIL.Image.fromarray(image),
|
||||||
|
prompt=config.prompt,
|
||||||
|
negative_prompt=config.negative_prompt,
|
||||||
|
mask=PIL.Image.fromarray(mask[:, :, -1], mode="L").convert("RGB"),
|
||||||
|
num_inference_steps=config.sd_steps,
|
||||||
|
# strength=config.sd_strength,
|
||||||
|
guidance_scale=config.sd_guidance_scale,
|
||||||
|
output_type="np",
|
||||||
|
callback_on_step_end=self.callback,
|
||||||
|
height=img_h,
|
||||||
|
width=img_w,
|
||||||
|
generator=torch.manual_seed(config.sd_seed),
|
||||||
|
brushnet_conditioning_scale=config.brushnet_conditioning_scale,
|
||||||
|
).images[0]
|
||||||
|
|
||||||
|
output = (output * 255).round().astype("uint8")
|
||||||
|
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||||
|
return output
|
1279
iopaint/model/brushnet/pipeline_brushnet.py
Normal file
1279
iopaint/model/brushnet/pipeline_brushnet.py
Normal file
File diff suppressed because it is too large
Load Diff
388
iopaint/model/brushnet/unet_2d_blocks.py
Normal file
388
iopaint/model/brushnet/unet_2d_blocks.py
Normal file
@ -0,0 +1,388 @@
|
|||||||
|
from typing import Dict, Any, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from diffusers.models.resnet import ResnetBlock2D
|
||||||
|
from diffusers.utils import is_torch_version
|
||||||
|
from diffusers.utils.torch_utils import apply_freeu
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
class MidBlock2D(nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
in_channels: int,
|
||||||
|
temb_channels: int,
|
||||||
|
dropout: float = 0.0,
|
||||||
|
num_layers: int = 1,
|
||||||
|
resnet_eps: float = 1e-6,
|
||||||
|
resnet_time_scale_shift: str = "default",
|
||||||
|
resnet_act_fn: str = "swish",
|
||||||
|
resnet_groups: int = 32,
|
||||||
|
resnet_pre_norm: bool = True,
|
||||||
|
output_scale_factor: float = 1.0,
|
||||||
|
use_linear_projection: bool = False,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.has_cross_attention = False
|
||||||
|
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||||
|
|
||||||
|
# there is always at least one resnet
|
||||||
|
resnets = [
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=in_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=output_scale_factor,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
for i in range(num_layers):
|
||||||
|
resnets.append(
|
||||||
|
ResnetBlock2D(
|
||||||
|
in_channels=in_channels,
|
||||||
|
out_channels=in_channels,
|
||||||
|
temb_channels=temb_channels,
|
||||||
|
eps=resnet_eps,
|
||||||
|
groups=resnet_groups,
|
||||||
|
dropout=dropout,
|
||||||
|
time_embedding_norm=resnet_time_scale_shift,
|
||||||
|
non_linearity=resnet_act_fn,
|
||||||
|
output_scale_factor=output_scale_factor,
|
||||||
|
pre_norm=resnet_pre_norm,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.resnets = nn.ModuleList(resnets)
|
||||||
|
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
temb: Optional[torch.FloatTensor] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
lora_scale = 1.0
|
||||||
|
hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
|
||||||
|
for resnet in self.resnets[1:]:
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
|
def create_custom_forward(module, return_dict=None):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
if return_dict is not None:
|
||||||
|
return module(*inputs, return_dict=return_dict)
|
||||||
|
else:
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet),
|
||||||
|
hidden_states,
|
||||||
|
temb,
|
||||||
|
**ckpt_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def DownBlock2D_forward(
|
||||||
|
self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0,
|
||||||
|
down_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||||
|
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||||
|
output_states = ()
|
||||||
|
|
||||||
|
for resnet in self.resnets:
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
if is_torch_version(">=", "1.11.0"):
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet), hidden_states, temb
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||||
|
|
||||||
|
if down_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + down_block_add_samples.pop(0)
|
||||||
|
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
|
||||||
|
if self.downsamplers is not None:
|
||||||
|
for downsampler in self.downsamplers:
|
||||||
|
hidden_states = downsampler(hidden_states, scale=scale)
|
||||||
|
|
||||||
|
if down_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + down_block_add_samples.pop(0) # todo: add before or after
|
||||||
|
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
|
||||||
|
return hidden_states, output_states
|
||||||
|
|
||||||
|
|
||||||
|
def CrossAttnDownBlock2D_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
temb: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
additional_residuals: Optional[torch.FloatTensor] = None,
|
||||||
|
down_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||||
|
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||||
|
output_states = ()
|
||||||
|
|
||||||
|
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||||
|
|
||||||
|
blocks = list(zip(self.resnets, self.attentions))
|
||||||
|
|
||||||
|
for i, (resnet, attn) in enumerate(blocks):
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
|
def create_custom_forward(module, return_dict=None):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
if return_dict is not None:
|
||||||
|
return module(*inputs, return_dict=return_dict)
|
||||||
|
else:
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet),
|
||||||
|
hidden_states,
|
||||||
|
temb,
|
||||||
|
**ckpt_kwargs,
|
||||||
|
)
|
||||||
|
hidden_states = attn(
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
else:
|
||||||
|
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||||
|
hidden_states = attn(
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# apply additional residuals to the output of the last pair of resnet and attention blocks
|
||||||
|
if i == len(blocks) - 1 and additional_residuals is not None:
|
||||||
|
hidden_states = hidden_states + additional_residuals
|
||||||
|
|
||||||
|
if down_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + down_block_add_samples.pop(0)
|
||||||
|
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
|
||||||
|
if self.downsamplers is not None:
|
||||||
|
for downsampler in self.downsamplers:
|
||||||
|
hidden_states = downsampler(hidden_states, scale=lora_scale)
|
||||||
|
|
||||||
|
if down_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + down_block_add_samples.pop(0) # todo: add before or after
|
||||||
|
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
|
||||||
|
return hidden_states, output_states
|
||||||
|
|
||||||
|
|
||||||
|
def CrossAttnUpBlock2D_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||||
|
temb: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
upsample_size: Optional[int] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
return_res_samples: Optional[bool] = False,
|
||||||
|
up_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
||||||
|
is_freeu_enabled = (
|
||||||
|
getattr(self, "s1", None)
|
||||||
|
and getattr(self, "s2", None)
|
||||||
|
and getattr(self, "b1", None)
|
||||||
|
and getattr(self, "b2", None)
|
||||||
|
)
|
||||||
|
if return_res_samples:
|
||||||
|
output_states = ()
|
||||||
|
|
||||||
|
for resnet, attn in zip(self.resnets, self.attentions):
|
||||||
|
# pop res hidden states
|
||||||
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
|
||||||
|
# FreeU: Only operate on the first two stages
|
||||||
|
if is_freeu_enabled:
|
||||||
|
hidden_states, res_hidden_states = apply_freeu(
|
||||||
|
self.resolution_idx,
|
||||||
|
hidden_states,
|
||||||
|
res_hidden_states,
|
||||||
|
s1=self.s1,
|
||||||
|
s2=self.s2,
|
||||||
|
b1=self.b1,
|
||||||
|
b2=self.b2,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
|
def create_custom_forward(module, return_dict=None):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
if return_dict is not None:
|
||||||
|
return module(*inputs, return_dict=return_dict)
|
||||||
|
else:
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet),
|
||||||
|
hidden_states,
|
||||||
|
temb,
|
||||||
|
**ckpt_kwargs,
|
||||||
|
)
|
||||||
|
hidden_states = attn(
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
else:
|
||||||
|
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||||
|
hidden_states = attn(
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
if return_res_samples:
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
if up_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + up_block_add_samples.pop(0)
|
||||||
|
|
||||||
|
if self.upsamplers is not None:
|
||||||
|
for upsampler in self.upsamplers:
|
||||||
|
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
|
||||||
|
if return_res_samples:
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
if up_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + up_block_add_samples.pop(0)
|
||||||
|
|
||||||
|
if return_res_samples:
|
||||||
|
return hidden_states, output_states
|
||||||
|
else:
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def UpBlock2D_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||||
|
temb: Optional[torch.FloatTensor] = None,
|
||||||
|
upsample_size: Optional[int] = None,
|
||||||
|
scale: float = 1.0,
|
||||||
|
return_res_samples: Optional[bool] = False,
|
||||||
|
up_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
is_freeu_enabled = (
|
||||||
|
getattr(self, "s1", None)
|
||||||
|
and getattr(self, "s2", None)
|
||||||
|
and getattr(self, "b1", None)
|
||||||
|
and getattr(self, "b2", None)
|
||||||
|
)
|
||||||
|
if return_res_samples:
|
||||||
|
output_states = ()
|
||||||
|
|
||||||
|
for resnet in self.resnets:
|
||||||
|
# pop res hidden states
|
||||||
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
|
||||||
|
# FreeU: Only operate on the first two stages
|
||||||
|
if is_freeu_enabled:
|
||||||
|
hidden_states, res_hidden_states = apply_freeu(
|
||||||
|
self.resolution_idx,
|
||||||
|
hidden_states,
|
||||||
|
res_hidden_states,
|
||||||
|
s1=self.s1,
|
||||||
|
s2=self.s2,
|
||||||
|
b1=self.b1,
|
||||||
|
b2=self.b2,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
if is_torch_version(">=", "1.11.0"):
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet), hidden_states, temb
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||||
|
|
||||||
|
if return_res_samples:
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
if up_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + up_block_add_samples.pop(0) # todo: add before or after
|
||||||
|
|
||||||
|
if self.upsamplers is not None:
|
||||||
|
for upsampler in self.upsamplers:
|
||||||
|
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
|
||||||
|
|
||||||
|
if return_res_samples:
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
if up_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + up_block_add_samples.pop(0) # todo: add before or after
|
||||||
|
|
||||||
|
if return_res_samples:
|
||||||
|
return hidden_states, output_states
|
||||||
|
else:
|
||||||
|
return hidden_states
|
@ -1,174 +1,29 @@
|
|||||||
# code copy from: https://github.com/parlance-zz/g-diffuser-bot
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def np_img_grey_to_rgb(data):
|
def expand_image(cv2_img, top: int, right: int, bottom: int, left: int):
|
||||||
if data.ndim == 3:
|
|
||||||
return data
|
|
||||||
return np.expand_dims(data, 2) * np.ones((1, 1, 3))
|
|
||||||
|
|
||||||
|
|
||||||
def convolve(data1, data2): # fast convolution with fft
|
|
||||||
if data1.ndim != data2.ndim: # promote to rgb if mismatch
|
|
||||||
if data1.ndim < 3:
|
|
||||||
data1 = np_img_grey_to_rgb(data1)
|
|
||||||
if data2.ndim < 3:
|
|
||||||
data2 = np_img_grey_to_rgb(data2)
|
|
||||||
return ifft2(fft2(data1) * fft2(data2))
|
|
||||||
|
|
||||||
|
|
||||||
def fft2(data):
|
|
||||||
if data.ndim > 2: # multiple channels
|
|
||||||
out_fft = np.zeros(
|
|
||||||
(data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128
|
|
||||||
)
|
|
||||||
for c in range(data.shape[2]):
|
|
||||||
c_data = data[:, :, c]
|
|
||||||
out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho")
|
|
||||||
out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
|
|
||||||
else: # single channel
|
|
||||||
out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
|
|
||||||
out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
|
|
||||||
out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
|
|
||||||
|
|
||||||
return out_fft
|
|
||||||
|
|
||||||
|
|
||||||
def ifft2(data):
|
|
||||||
if data.ndim > 2: # multiple channels
|
|
||||||
out_ifft = np.zeros(
|
|
||||||
(data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128
|
|
||||||
)
|
|
||||||
for c in range(data.shape[2]):
|
|
||||||
c_data = data[:, :, c]
|
|
||||||
out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho")
|
|
||||||
out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
|
|
||||||
else: # single channel
|
|
||||||
out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
|
|
||||||
out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho")
|
|
||||||
out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
|
|
||||||
|
|
||||||
return out_ifft
|
|
||||||
|
|
||||||
|
|
||||||
def get_gradient_kernel(width, height, std=3.14, mode="linear"):
|
|
||||||
window_scale_x = float(
|
|
||||||
width / min(width, height)
|
|
||||||
) # for non-square aspect ratios we still want a circular kernel
|
|
||||||
window_scale_y = float(height / min(width, height))
|
|
||||||
if mode == "gaussian":
|
|
||||||
x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
|
|
||||||
kx = np.exp(-x * x * std)
|
|
||||||
if window_scale_x != window_scale_y:
|
|
||||||
y = (np.arange(height) / height * 2.0 - 1.0) * window_scale_y
|
|
||||||
ky = np.exp(-y * y * std)
|
|
||||||
else:
|
|
||||||
y = x
|
|
||||||
ky = kx
|
|
||||||
return np.outer(kx, ky)
|
|
||||||
elif mode == "linear":
|
|
||||||
x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
|
|
||||||
if window_scale_x != window_scale_y:
|
|
||||||
y = (np.arange(height) / height * 2.0 - 1.0) * window_scale_y
|
|
||||||
else:
|
|
||||||
y = x
|
|
||||||
return np.clip(1.0 - np.sqrt(np.add.outer(x * x, y * y)) * std / 3.14, 0.0, 1.0)
|
|
||||||
else:
|
|
||||||
raise Exception("Error: Unknown mode in get_gradient_kernel: {0}".format(mode))
|
|
||||||
|
|
||||||
|
|
||||||
def image_blur(data, std=3.14, mode="linear"):
|
|
||||||
width = data.shape[0]
|
|
||||||
height = data.shape[1]
|
|
||||||
kernel = get_gradient_kernel(width, height, std, mode=mode)
|
|
||||||
return np.real(convolve(data, kernel / np.sqrt(np.sum(kernel * kernel))))
|
|
||||||
|
|
||||||
|
|
||||||
def soften_mask(mask_img, softness, space):
|
|
||||||
if softness == 0:
|
|
||||||
return mask_img
|
|
||||||
softness = min(softness, 1.0)
|
|
||||||
space = np.clip(space, 0.0, 1.0)
|
|
||||||
original_max_opacity = np.max(mask_img)
|
|
||||||
out_mask = mask_img <= 0.0
|
|
||||||
blurred_mask = image_blur(mask_img, 3.5 / softness, mode="linear")
|
|
||||||
blurred_mask = np.maximum(blurred_mask - np.max(blurred_mask[out_mask]), 0.0)
|
|
||||||
mask_img *= blurred_mask # preserve partial opacity in original input mask
|
|
||||||
mask_img /= np.max(mask_img) # renormalize
|
|
||||||
mask_img = np.clip(mask_img - space, 0.0, 1.0) # make space
|
|
||||||
mask_img /= np.max(mask_img) # and renormalize again
|
|
||||||
mask_img *= original_max_opacity # restore original max opacity
|
|
||||||
return mask_img
|
|
||||||
|
|
||||||
|
|
||||||
def expand_image(
|
|
||||||
cv2_img, top: int, right: int, bottom: int, left: int, softness: float, space: float
|
|
||||||
):
|
|
||||||
assert cv2_img.shape[2] == 3
|
assert cv2_img.shape[2] == 3
|
||||||
origin_h, origin_w = cv2_img.shape[:2]
|
origin_h, origin_w = cv2_img.shape[:2]
|
||||||
new_width = cv2_img.shape[1] + left + right
|
|
||||||
new_height = cv2_img.shape[0] + top + bottom
|
|
||||||
|
|
||||||
# TODO: which is better?
|
# TODO: which is better?
|
||||||
# new_img = np.random.randint(0, 255, (new_height, new_width, 3), np.uint8)
|
# new_img = np.ones((new_height, new_width, 3), np.uint8) * 255
|
||||||
new_img = cv2.copyMakeBorder(
|
|
||||||
cv2_img, top, bottom, left, right, cv2.BORDER_REPLICATE
|
|
||||||
)
|
|
||||||
mask_img = np.zeros((new_height, new_width), np.uint8)
|
|
||||||
mask_img[top: top + cv2_img.shape[0], left: left + cv2_img.shape[1]] = 255
|
|
||||||
|
|
||||||
if softness > 0.0:
|
|
||||||
mask_img = soften_mask(mask_img / 255.0, softness / 100.0, space / 100.0)
|
|
||||||
mask_img = (np.clip(mask_img, 0.0, 1.0) * 255.0).astype(np.uint8)
|
|
||||||
|
|
||||||
mask_image = 255.0 - mask_img # extract mask from alpha channel and invert
|
|
||||||
rgb_init_image = (
|
|
||||||
0.0 + new_img[:, :, 0:3]
|
|
||||||
) # strip mask from init_img leaving only rgb channels
|
|
||||||
|
|
||||||
hard_mask = np.zeros_like(cv2_img[:, :, 0])
|
|
||||||
if top != 0:
|
|
||||||
hard_mask[0: origin_h // 2, :] = 255
|
|
||||||
if bottom != 0:
|
|
||||||
hard_mask[origin_h // 2:, :] = 255
|
|
||||||
if left != 0:
|
|
||||||
hard_mask[:, 0: origin_w // 2] = 255
|
|
||||||
if right != 0:
|
|
||||||
hard_mask[:, origin_w // 2:] = 255
|
|
||||||
|
|
||||||
hard_mask = cv2.copyMakeBorder(
|
|
||||||
hard_mask, top, bottom, left, right, cv2.BORDER_CONSTANT, value=255
|
|
||||||
)
|
|
||||||
mask_image = np.where(hard_mask > 0, mask_image, 0)
|
|
||||||
return rgb_init_image.astype(np.uint8), mask_image.astype(np.uint8)
|
|
||||||
|
|
||||||
|
|
||||||
def expand_image2(
|
|
||||||
cv2_img, top: int, right: int, bottom: int, left: int, softness: float, space: float
|
|
||||||
):
|
|
||||||
assert cv2_img.shape[2] == 3
|
|
||||||
origin_h, origin_w = cv2_img.shape[:2]
|
|
||||||
new_width = cv2_img.shape[1] + left + right
|
|
||||||
new_height = cv2_img.shape[0] + top + bottom
|
|
||||||
|
|
||||||
# TODO: which is better?
|
|
||||||
# new_img = np.random.randint(0, 255, (new_height, new_width, 3), np.uint8)
|
|
||||||
new_img = cv2.copyMakeBorder(
|
new_img = cv2.copyMakeBorder(
|
||||||
cv2_img, top, bottom, left, right, cv2.BORDER_REPLICATE
|
cv2_img, top, bottom, left, right, cv2.BORDER_REPLICATE
|
||||||
)
|
)
|
||||||
|
|
||||||
inner_padding_left = 13 if left > 0 else 0
|
inner_padding_left = 0 if left > 0 else 0
|
||||||
inner_padding_right = 13 if right > 0 else 0
|
inner_padding_right = 0 if right > 0 else 0
|
||||||
inner_padding_top = 13 if top > 0 else 0
|
inner_padding_top = 0 if top > 0 else 0
|
||||||
inner_padding_bottom = 13 if bottom > 0 else 0
|
inner_padding_bottom = 0 if bottom > 0 else 0
|
||||||
|
|
||||||
mask_image = np.zeros(
|
mask_image = np.zeros(
|
||||||
(
|
(
|
||||||
origin_h - inner_padding_top - inner_padding_bottom
|
origin_h - inner_padding_top - inner_padding_bottom,
|
||||||
, origin_w - inner_padding_left - inner_padding_right
|
origin_w - inner_padding_left - inner_padding_right,
|
||||||
),
|
),
|
||||||
np.uint8)
|
np.uint8,
|
||||||
|
)
|
||||||
mask_image = cv2.copyMakeBorder(
|
mask_image = cv2.copyMakeBorder(
|
||||||
mask_image,
|
mask_image,
|
||||||
top + inner_padding_top,
|
top + inner_padding_top,
|
||||||
@ -176,11 +31,11 @@ def expand_image2(
|
|||||||
left + inner_padding_left,
|
left + inner_padding_left,
|
||||||
right + inner_padding_right,
|
right + inner_padding_right,
|
||||||
cv2.BORDER_CONSTANT,
|
cv2.BORDER_CONSTANT,
|
||||||
value=255
|
value=255,
|
||||||
)
|
)
|
||||||
# k = 2*int(min(origin_h, origin_w) // 6)+1
|
# k = 2*int(min(origin_h, origin_w) // 6)+1
|
||||||
k = 7
|
# k = 7
|
||||||
mask_image = cv2.GaussianBlur(mask_image, (k, k), 0)
|
# mask_image = cv2.GaussianBlur(mask_image, (k, k), 0)
|
||||||
return new_img, mask_image
|
return new_img, mask_image
|
||||||
|
|
||||||
|
|
||||||
@ -190,7 +45,7 @@ if __name__ == "__main__":
|
|||||||
current_dir = Path(__file__).parent.absolute().resolve()
|
current_dir = Path(__file__).parent.absolute().resolve()
|
||||||
image_path = "/Users/cwq/code/github/IOPaint/iopaint/tests/bunny.jpeg"
|
image_path = "/Users/cwq/code/github/IOPaint/iopaint/tests/bunny.jpeg"
|
||||||
init_image = cv2.imread(str(image_path))
|
init_image = cv2.imread(str(image_path))
|
||||||
init_image, mask_image = expand_image2(
|
init_image, mask_image = expand_image(
|
||||||
init_image,
|
init_image,
|
||||||
top=0,
|
top=0,
|
||||||
right=0,
|
right=0,
|
||||||
|
File diff suppressed because it is too large
Load Diff
186
iopaint/model/power_paint/power_paint_v2.py
Normal file
186
iopaint/model/power_paint/power_paint_v2.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
from itertools import chain
|
||||||
|
|
||||||
|
import PIL.Image
|
||||||
|
import cv2
|
||||||
|
import torch
|
||||||
|
from iopaint.model.original_sd_configs import get_config_files
|
||||||
|
from loguru import logger
|
||||||
|
from transformers import CLIPTextModel, CLIPTokenizer
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ..base import DiffusionInpaintModel
|
||||||
|
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||||
|
from ..utils import (
|
||||||
|
get_torch_dtype,
|
||||||
|
enable_low_mem,
|
||||||
|
is_local_files_only,
|
||||||
|
handle_from_pretrained_exceptions,
|
||||||
|
)
|
||||||
|
from .powerpaint_tokenizer import task_to_prompt
|
||||||
|
from iopaint.schema import InpaintRequest, ModelType
|
||||||
|
from .v2.BrushNet_CA import BrushNetModel
|
||||||
|
from .v2.unet_2d_condition import UNet2DConditionModel_forward
|
||||||
|
from .v2.unet_2d_blocks import (
|
||||||
|
CrossAttnDownBlock2D_forward,
|
||||||
|
DownBlock2D_forward,
|
||||||
|
CrossAttnUpBlock2D_forward,
|
||||||
|
UpBlock2D_forward,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PowerPaintV2(DiffusionInpaintModel):
|
||||||
|
pad_mod = 8
|
||||||
|
min_size = 512
|
||||||
|
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
|
||||||
|
hf_model_id = "Sanster/PowerPaint_v2"
|
||||||
|
|
||||||
|
def init_model(self, device: torch.device, **kwargs):
|
||||||
|
from .v2.pipeline_PowerPaint_Brushnet_CA import (
|
||||||
|
StableDiffusionPowerPaintBrushNetPipeline,
|
||||||
|
)
|
||||||
|
from .powerpaint_tokenizer import PowerPaintTokenizer
|
||||||
|
|
||||||
|
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||||
|
model_kwargs = {"local_files_only": is_local_files_only(**kwargs)}
|
||||||
|
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
||||||
|
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||||
|
model_kwargs.update(
|
||||||
|
dict(
|
||||||
|
safety_checker=None,
|
||||||
|
feature_extractor=None,
|
||||||
|
requires_safety_checker=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
text_encoder_brushnet = CLIPTextModel.from_pretrained(
|
||||||
|
self.hf_model_id,
|
||||||
|
subfolder="text_encoder_brushnet",
|
||||||
|
variant="fp16",
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
local_files_only=model_kwargs["local_files_only"],
|
||||||
|
)
|
||||||
|
|
||||||
|
brushnet = BrushNetModel.from_pretrained(
|
||||||
|
self.hf_model_id,
|
||||||
|
subfolder="PowerPaint_Brushnet",
|
||||||
|
variant="fp16",
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
local_files_only=model_kwargs["local_files_only"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.model_info.is_single_file_diffusers:
|
||||||
|
if self.model_info.model_type == ModelType.DIFFUSERS_SD:
|
||||||
|
model_kwargs["num_in_channels"] = 4
|
||||||
|
else:
|
||||||
|
model_kwargs["num_in_channels"] = 9
|
||||||
|
|
||||||
|
pipe = StableDiffusionPowerPaintBrushNetPipeline.from_single_file(
|
||||||
|
self.model_id_or_path,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
load_safety_checker=False,
|
||||||
|
original_config_file=get_config_files()["v1"],
|
||||||
|
brushnet=brushnet,
|
||||||
|
text_encoder_brushnet=text_encoder_brushnet,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
pipe = handle_from_pretrained_exceptions(
|
||||||
|
StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
|
||||||
|
pretrained_model_name_or_path=self.model_id_or_path,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
brushnet=brushnet,
|
||||||
|
text_encoder_brushnet=text_encoder_brushnet,
|
||||||
|
variant="fp16",
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
|
pipe.tokenizer = PowerPaintTokenizer(
|
||||||
|
CLIPTokenizer.from_pretrained(self.hf_model_id, subfolder="tokenizer")
|
||||||
|
)
|
||||||
|
self.model = pipe
|
||||||
|
|
||||||
|
enable_low_mem(self.model, kwargs.get("low_mem", False))
|
||||||
|
|
||||||
|
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||||
|
logger.info("Enable sequential cpu offload")
|
||||||
|
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||||
|
else:
|
||||||
|
self.model = self.model.to(device)
|
||||||
|
if kwargs["sd_cpu_textencoder"]:
|
||||||
|
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||||
|
self.model.text_encoder = CPUTextEncoderWrapper(
|
||||||
|
self.model.text_encoder, torch_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
self.callback = kwargs.pop("callback", None)
|
||||||
|
|
||||||
|
# Monkey patch the forward method of the UNet to use the brushnet_unet_forward method
|
||||||
|
self.model.unet.forward = UNet2DConditionModel_forward.__get__(
|
||||||
|
self.model.unet, self.model.unet.__class__
|
||||||
|
)
|
||||||
|
|
||||||
|
# Monkey patch unet down_blocks to use CrossAttnDownBlock2D_forward
|
||||||
|
for down_block in chain(
|
||||||
|
self.model.unet.down_blocks, self.model.brushnet.down_blocks
|
||||||
|
):
|
||||||
|
if down_block.__class__.__name__ == "CrossAttnDownBlock2D":
|
||||||
|
down_block.forward = CrossAttnDownBlock2D_forward.__get__(
|
||||||
|
down_block, down_block.__class__
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
down_block.forward = DownBlock2D_forward.__get__(
|
||||||
|
down_block, down_block.__class__
|
||||||
|
)
|
||||||
|
|
||||||
|
for up_block in chain(self.model.unet.up_blocks, self.model.brushnet.up_blocks):
|
||||||
|
if up_block.__class__.__name__ == "CrossAttnUpBlock2D":
|
||||||
|
up_block.forward = CrossAttnUpBlock2D_forward.__get__(
|
||||||
|
up_block, up_block.__class__
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
up_block.forward = UpBlock2D_forward.__get__(
|
||||||
|
up_block, up_block.__class__
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, image, mask, config: InpaintRequest):
|
||||||
|
"""Input image and output image have same size
|
||||||
|
image: [H, W, C] RGB
|
||||||
|
mask: [H, W, 1] 255 means area to repaint
|
||||||
|
return: BGR IMAGE
|
||||||
|
"""
|
||||||
|
self.set_scheduler(config)
|
||||||
|
|
||||||
|
image = image * (1 - mask / 255.0)
|
||||||
|
img_h, img_w = image.shape[:2]
|
||||||
|
|
||||||
|
image = PIL.Image.fromarray(image.astype(np.uint8))
|
||||||
|
mask = PIL.Image.fromarray(mask[:, :, -1], mode="L").convert("RGB")
|
||||||
|
|
||||||
|
promptA, promptB, negative_promptA, negative_promptB = task_to_prompt(
|
||||||
|
config.powerpaint_task
|
||||||
|
)
|
||||||
|
|
||||||
|
output = self.model(
|
||||||
|
image=image,
|
||||||
|
mask=mask,
|
||||||
|
promptA=promptA,
|
||||||
|
promptB=promptB,
|
||||||
|
promptU=config.prompt,
|
||||||
|
tradoff=config.fitting_degree,
|
||||||
|
tradoff_nag=config.fitting_degree,
|
||||||
|
negative_promptA=negative_promptA,
|
||||||
|
negative_promptB=negative_promptB,
|
||||||
|
negative_promptU=config.negative_prompt,
|
||||||
|
num_inference_steps=config.sd_steps,
|
||||||
|
# strength=config.sd_strength,
|
||||||
|
brushnet_conditioning_scale=1.0,
|
||||||
|
guidance_scale=config.sd_guidance_scale,
|
||||||
|
output_type="np",
|
||||||
|
callback_on_step_end=self.callback,
|
||||||
|
height=img_h,
|
||||||
|
width=img_w,
|
||||||
|
generator=torch.manual_seed(config.sd_seed),
|
||||||
|
).images[0]
|
||||||
|
|
||||||
|
output = (output * 255).round().astype("uint8")
|
||||||
|
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
|
||||||
|
return output
|
@ -1,8 +1,6 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import copy
|
import copy
|
||||||
import random
|
import random
|
||||||
from typing import Any, List, Optional, Union
|
from typing import Any, List, Union
|
||||||
from transformers import CLIPTokenizer
|
from transformers import CLIPTokenizer
|
||||||
|
|
||||||
from iopaint.schema import PowerPaintTask
|
from iopaint.schema import PowerPaintTask
|
||||||
@ -14,6 +12,11 @@ def add_task_to_prompt(prompt, negative_prompt, task: PowerPaintTask):
|
|||||||
promptB = prompt + " P_ctxt"
|
promptB = prompt + " P_ctxt"
|
||||||
negative_promptA = negative_prompt + " P_obj"
|
negative_promptA = negative_prompt + " P_obj"
|
||||||
negative_promptB = negative_prompt + " P_obj"
|
negative_promptB = negative_prompt + " P_obj"
|
||||||
|
elif task == PowerPaintTask.context_aware:
|
||||||
|
promptA = prompt + " P_ctxt"
|
||||||
|
promptB = prompt + " P_ctxt"
|
||||||
|
negative_promptA = negative_prompt
|
||||||
|
negative_promptB = negative_prompt
|
||||||
elif task == PowerPaintTask.shape_guided:
|
elif task == PowerPaintTask.shape_guided:
|
||||||
promptA = prompt + " P_shape"
|
promptA = prompt + " P_shape"
|
||||||
promptB = prompt + " P_ctxt"
|
promptB = prompt + " P_ctxt"
|
||||||
@ -33,6 +36,18 @@ def add_task_to_prompt(prompt, negative_prompt, task: PowerPaintTask):
|
|||||||
return promptA, promptB, negative_promptA, negative_promptB
|
return promptA, promptB, negative_promptA, negative_promptB
|
||||||
|
|
||||||
|
|
||||||
|
def task_to_prompt(task: PowerPaintTask):
|
||||||
|
promptA, promptB, negative_promptA, negative_promptB = add_task_to_prompt(
|
||||||
|
"", "", task
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
promptA.strip(),
|
||||||
|
promptB.strip(),
|
||||||
|
negative_promptA.strip(),
|
||||||
|
negative_promptB.strip(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PowerPaintTokenizer:
|
class PowerPaintTokenizer:
|
||||||
def __init__(self, tokenizer: CLIPTokenizer):
|
def __init__(self, tokenizer: CLIPTokenizer):
|
||||||
self.wrapped = tokenizer
|
self.wrapped = tokenizer
|
||||||
@ -237,304 +252,3 @@ class PowerPaintTokenizer:
|
|||||||
return text
|
return text
|
||||||
replaced_text = self.replace_text_with_placeholder_tokens(text)
|
replaced_text = self.replace_text_with_placeholder_tokens(text)
|
||||||
return replaced_text
|
return replaced_text
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingLayerWithFixes(nn.Module):
|
|
||||||
"""The revised embedding layer to support external embeddings. This design
|
|
||||||
of this class is inspired by https://github.com/AUTOMATIC1111/stable-
|
|
||||||
diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hi
|
|
||||||
jack.py#L224 # noqa.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
wrapped (nn.Emebdding): The embedding layer to be wrapped.
|
|
||||||
external_embeddings (Union[dict, List[dict]], optional): The external
|
|
||||||
embeddings added to this layer. Defaults to None.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
wrapped: nn.Embedding,
|
|
||||||
external_embeddings: Optional[Union[dict, List[dict]]] = None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.wrapped = wrapped
|
|
||||||
self.num_embeddings = wrapped.weight.shape[0]
|
|
||||||
|
|
||||||
self.external_embeddings = []
|
|
||||||
if external_embeddings:
|
|
||||||
self.add_embeddings(external_embeddings)
|
|
||||||
|
|
||||||
self.trainable_embeddings = nn.ParameterDict()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def weight(self):
|
|
||||||
"""Get the weight of wrapped embedding layer."""
|
|
||||||
return self.wrapped.weight
|
|
||||||
|
|
||||||
def check_duplicate_names(self, embeddings: List[dict]):
|
|
||||||
"""Check whether duplicate names exist in list of 'external
|
|
||||||
embeddings'.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
embeddings (List[dict]): A list of embedding to be check.
|
|
||||||
"""
|
|
||||||
names = [emb["name"] for emb in embeddings]
|
|
||||||
assert len(names) == len(set(names)), (
|
|
||||||
"Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'"
|
|
||||||
)
|
|
||||||
|
|
||||||
def check_ids_overlap(self, embeddings):
|
|
||||||
"""Check whether overlap exist in token ids of 'external_embeddings'.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
embeddings (List[dict]): A list of embedding to be check.
|
|
||||||
"""
|
|
||||||
ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings]
|
|
||||||
ids_range.sort() # sort by 'start'
|
|
||||||
# check if 'end' has overlapping
|
|
||||||
for idx in range(len(ids_range) - 1):
|
|
||||||
name1, name2 = ids_range[idx][-1], ids_range[idx + 1][-1]
|
|
||||||
assert ids_range[idx][1] <= ids_range[idx + 1][0], (
|
|
||||||
f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'."
|
|
||||||
)
|
|
||||||
|
|
||||||
def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
|
|
||||||
"""Add external embeddings to this layer.
|
|
||||||
|
|
||||||
Use case:
|
|
||||||
|
|
||||||
>>> 1. Add token to tokenizer and get the token id.
|
|
||||||
>>> tokenizer = TokenizerWrapper('openai/clip-vit-base-patch32')
|
|
||||||
>>> # 'how much' in kiswahili
|
|
||||||
>>> tokenizer.add_placeholder_tokens('ngapi', num_vec_per_token=4)
|
|
||||||
>>>
|
|
||||||
>>> 2. Add external embeddings to the model.
|
|
||||||
>>> new_embedding = {
|
|
||||||
>>> 'name': 'ngapi', # 'how much' in kiswahili
|
|
||||||
>>> 'embedding': torch.ones(1, 15) * 4,
|
|
||||||
>>> 'start': tokenizer.get_token_info('kwaheri')['start'],
|
|
||||||
>>> 'end': tokenizer.get_token_info('kwaheri')['end'],
|
|
||||||
>>> 'trainable': False # if True, will registry as a parameter
|
|
||||||
>>> }
|
|
||||||
>>> embedding_layer = nn.Embedding(10, 15)
|
|
||||||
>>> embedding_layer_wrapper = EmbeddingLayerWithFixes(embedding_layer)
|
|
||||||
>>> embedding_layer_wrapper.add_embeddings(new_embedding)
|
|
||||||
>>>
|
|
||||||
>>> 3. Forward tokenizer and embedding layer!
|
|
||||||
>>> input_text = ['hello, ngapi!', 'hello my friend, ngapi?']
|
|
||||||
>>> input_ids = tokenizer(
|
|
||||||
>>> input_text, padding='max_length', truncation=True,
|
|
||||||
>>> return_tensors='pt')['input_ids']
|
|
||||||
>>> out_feat = embedding_layer_wrapper(input_ids)
|
|
||||||
>>>
|
|
||||||
>>> 4. Let's validate the result!
|
|
||||||
>>> assert (out_feat[0, 3: 7] == 2.3).all()
|
|
||||||
>>> assert (out_feat[2, 5: 9] == 2.3).all()
|
|
||||||
|
|
||||||
Args:
|
|
||||||
embeddings (Union[dict, list[dict]]): The external embeddings to
|
|
||||||
be added. Each dict must contain the following 4 fields: 'name'
|
|
||||||
(the name of this embedding), 'embedding' (the embedding
|
|
||||||
tensor), 'start' (the start token id of this embedding), 'end'
|
|
||||||
(the end token id of this embedding). For example:
|
|
||||||
`{name: NAME, start: START, end: END, embedding: torch.Tensor}`
|
|
||||||
"""
|
|
||||||
if isinstance(embeddings, dict):
|
|
||||||
embeddings = [embeddings]
|
|
||||||
|
|
||||||
self.external_embeddings += embeddings
|
|
||||||
self.check_duplicate_names(self.external_embeddings)
|
|
||||||
self.check_ids_overlap(self.external_embeddings)
|
|
||||||
|
|
||||||
# set for trainable
|
|
||||||
added_trainable_emb_info = []
|
|
||||||
for embedding in embeddings:
|
|
||||||
trainable = embedding.get("trainable", False)
|
|
||||||
if trainable:
|
|
||||||
name = embedding["name"]
|
|
||||||
embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
|
|
||||||
self.trainable_embeddings[name] = embedding["embedding"]
|
|
||||||
added_trainable_emb_info.append(name)
|
|
||||||
|
|
||||||
added_emb_info = [emb["name"] for emb in embeddings]
|
|
||||||
added_emb_info = ", ".join(added_emb_info)
|
|
||||||
print(f"Successfully add external embeddings: {added_emb_info}.", "current")
|
|
||||||
|
|
||||||
if added_trainable_emb_info:
|
|
||||||
added_trainable_emb_info = ", ".join(added_trainable_emb_info)
|
|
||||||
print(
|
|
||||||
"Successfully add trainable external embeddings: "
|
|
||||||
f"{added_trainable_emb_info}",
|
|
||||||
"current",
|
|
||||||
)
|
|
||||||
|
|
||||||
def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""Replace external input ids to 0.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_ids (torch.Tensor): The input ids to be replaced.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: The replaced input ids.
|
|
||||||
"""
|
|
||||||
input_ids_fwd = input_ids.clone()
|
|
||||||
input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
|
|
||||||
return input_ids_fwd
|
|
||||||
|
|
||||||
def replace_embeddings(
|
|
||||||
self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""Replace external embedding to the embedding layer. Noted that, in
|
|
||||||
this function we use `torch.cat` to avoid inplace modification.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_ids (torch.Tensor): The original token ids. Shape like
|
|
||||||
[LENGTH, ].
|
|
||||||
embedding (torch.Tensor): The embedding of token ids after
|
|
||||||
`replace_input_ids` function.
|
|
||||||
external_embedding (dict): The external embedding to be replaced.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
torch.Tensor: The replaced embedding.
|
|
||||||
"""
|
|
||||||
new_embedding = []
|
|
||||||
|
|
||||||
name = external_embedding["name"]
|
|
||||||
start = external_embedding["start"]
|
|
||||||
end = external_embedding["end"]
|
|
||||||
target_ids_to_replace = [i for i in range(start, end)]
|
|
||||||
ext_emb = external_embedding["embedding"]
|
|
||||||
|
|
||||||
# do not need to replace
|
|
||||||
if not (input_ids == start).any():
|
|
||||||
return embedding
|
|
||||||
|
|
||||||
# start replace
|
|
||||||
s_idx, e_idx = 0, 0
|
|
||||||
while e_idx < len(input_ids):
|
|
||||||
if input_ids[e_idx] == start:
|
|
||||||
if e_idx != 0:
|
|
||||||
# add embedding do not need to replace
|
|
||||||
new_embedding.append(embedding[s_idx:e_idx])
|
|
||||||
|
|
||||||
# check if the next embedding need to replace is valid
|
|
||||||
actually_ids_to_replace = [
|
|
||||||
int(i) for i in input_ids[e_idx : e_idx + end - start]
|
|
||||||
]
|
|
||||||
assert actually_ids_to_replace == target_ids_to_replace, (
|
|
||||||
f"Invalid 'input_ids' in position: {s_idx} to {e_idx}. "
|
|
||||||
f"Expect '{target_ids_to_replace}' for embedding "
|
|
||||||
f"'{name}' but found '{actually_ids_to_replace}'."
|
|
||||||
)
|
|
||||||
|
|
||||||
new_embedding.append(ext_emb)
|
|
||||||
|
|
||||||
s_idx = e_idx + end - start
|
|
||||||
e_idx = s_idx + 1
|
|
||||||
else:
|
|
||||||
e_idx += 1
|
|
||||||
|
|
||||||
if e_idx == len(input_ids):
|
|
||||||
new_embedding.append(embedding[s_idx:e_idx])
|
|
||||||
|
|
||||||
return torch.cat(new_embedding, dim=0)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None
|
|
||||||
):
|
|
||||||
"""The forward function.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_ids (torch.Tensor): The token ids shape like [bz, LENGTH] or
|
|
||||||
[LENGTH, ].
|
|
||||||
external_embeddings (Optional[List[dict]]): The external
|
|
||||||
embeddings. If not passed, only `self.external_embeddings`
|
|
||||||
will be used. Defaults to None.
|
|
||||||
|
|
||||||
input_ids: shape like [bz, LENGTH] or [LENGTH].
|
|
||||||
"""
|
|
||||||
assert input_ids.ndim in [1, 2]
|
|
||||||
if input_ids.ndim == 1:
|
|
||||||
input_ids = input_ids.unsqueeze(0)
|
|
||||||
|
|
||||||
if external_embeddings is None and not self.external_embeddings:
|
|
||||||
return self.wrapped(input_ids)
|
|
||||||
|
|
||||||
input_ids_fwd = self.replace_input_ids(input_ids)
|
|
||||||
inputs_embeds = self.wrapped(input_ids_fwd)
|
|
||||||
|
|
||||||
vecs = []
|
|
||||||
|
|
||||||
if external_embeddings is None:
|
|
||||||
external_embeddings = []
|
|
||||||
elif isinstance(external_embeddings, dict):
|
|
||||||
external_embeddings = [external_embeddings]
|
|
||||||
embeddings = self.external_embeddings + external_embeddings
|
|
||||||
|
|
||||||
for input_id, embedding in zip(input_ids, inputs_embeds):
|
|
||||||
new_embedding = embedding
|
|
||||||
for external_embedding in embeddings:
|
|
||||||
new_embedding = self.replace_embeddings(
|
|
||||||
input_id, new_embedding, external_embedding
|
|
||||||
)
|
|
||||||
vecs.append(new_embedding)
|
|
||||||
|
|
||||||
return torch.stack(vecs)
|
|
||||||
|
|
||||||
|
|
||||||
def add_tokens(
|
|
||||||
tokenizer,
|
|
||||||
text_encoder,
|
|
||||||
placeholder_tokens: list,
|
|
||||||
initialize_tokens: list = None,
|
|
||||||
num_vectors_per_token: int = 1,
|
|
||||||
):
|
|
||||||
"""Add token for training.
|
|
||||||
|
|
||||||
# TODO: support add tokens as dict, then we can load pretrained tokens.
|
|
||||||
"""
|
|
||||||
if initialize_tokens is not None:
|
|
||||||
assert len(initialize_tokens) == len(
|
|
||||||
placeholder_tokens
|
|
||||||
), "placeholder_token should be the same length as initialize_token"
|
|
||||||
for ii in range(len(placeholder_tokens)):
|
|
||||||
tokenizer.add_placeholder_token(
|
|
||||||
placeholder_tokens[ii], num_vec_per_token=num_vectors_per_token
|
|
||||||
)
|
|
||||||
|
|
||||||
# text_encoder.set_embedding_layer()
|
|
||||||
embedding_layer = text_encoder.text_model.embeddings.token_embedding
|
|
||||||
text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes(
|
|
||||||
embedding_layer
|
|
||||||
)
|
|
||||||
embedding_layer = text_encoder.text_model.embeddings.token_embedding
|
|
||||||
|
|
||||||
assert embedding_layer is not None, (
|
|
||||||
"Do not support get embedding layer for current text encoder. "
|
|
||||||
"Please check your configuration."
|
|
||||||
)
|
|
||||||
initialize_embedding = []
|
|
||||||
if initialize_tokens is not None:
|
|
||||||
for ii in range(len(placeholder_tokens)):
|
|
||||||
init_id = tokenizer(initialize_tokens[ii]).input_ids[1]
|
|
||||||
temp_embedding = embedding_layer.weight[init_id]
|
|
||||||
initialize_embedding.append(
|
|
||||||
temp_embedding[None, ...].repeat(num_vectors_per_token, 1)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
for ii in range(len(placeholder_tokens)):
|
|
||||||
init_id = tokenizer("a").input_ids[1]
|
|
||||||
temp_embedding = embedding_layer.weight[init_id]
|
|
||||||
len_emb = temp_embedding.shape[0]
|
|
||||||
init_weight = (torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0
|
|
||||||
initialize_embedding.append(init_weight)
|
|
||||||
|
|
||||||
# initialize_embedding = torch.cat(initialize_embedding,dim=0)
|
|
||||||
|
|
||||||
token_info_all = []
|
|
||||||
for ii in range(len(placeholder_tokens)):
|
|
||||||
token_info = tokenizer.get_token_info(placeholder_tokens[ii])
|
|
||||||
token_info["embedding"] = initialize_embedding[ii]
|
|
||||||
token_info["trainable"] = True
|
|
||||||
token_info_all.append(token_info)
|
|
||||||
embedding_layer.add_embeddings(token_info_all)
|
|
||||||
|
1094
iopaint/model/power_paint/v2/BrushNet_CA.py
Normal file
1094
iopaint/model/power_paint/v2/BrushNet_CA.py
Normal file
File diff suppressed because it is too large
Load Diff
1690
iopaint/model/power_paint/v2/pipeline_PowerPaint_Brushnet_CA.py
Normal file
1690
iopaint/model/power_paint/v2/pipeline_PowerPaint_Brushnet_CA.py
Normal file
File diff suppressed because it is too large
Load Diff
342
iopaint/model/power_paint/v2/unet_2d_blocks.py
Normal file
342
iopaint/model/power_paint/v2/unet_2d_blocks.py
Normal file
@ -0,0 +1,342 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from diffusers.utils import is_torch_version, logging
|
||||||
|
from diffusers.utils.torch_utils import apply_freeu
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
def CrossAttnDownBlock2D_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
temb: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
additional_residuals: Optional[torch.FloatTensor] = None,
|
||||||
|
down_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||||
|
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||||
|
output_states = ()
|
||||||
|
|
||||||
|
lora_scale = (
|
||||||
|
cross_attention_kwargs.get("scale", 1.0)
|
||||||
|
if cross_attention_kwargs is not None
|
||||||
|
else 1.0
|
||||||
|
)
|
||||||
|
|
||||||
|
blocks = list(zip(self.resnets, self.attentions))
|
||||||
|
|
||||||
|
for i, (resnet, attn) in enumerate(blocks):
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
|
def create_custom_forward(module, return_dict=None):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
if return_dict is not None:
|
||||||
|
return module(*inputs, return_dict=return_dict)
|
||||||
|
else:
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
ckpt_kwargs: Dict[str, Any] = (
|
||||||
|
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||||
|
)
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet),
|
||||||
|
hidden_states,
|
||||||
|
temb,
|
||||||
|
**ckpt_kwargs,
|
||||||
|
)
|
||||||
|
hidden_states = attn(
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
else:
|
||||||
|
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||||
|
hidden_states = attn(
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
# apply additional residuals to the output of the last pair of resnet and attention blocks
|
||||||
|
if i == len(blocks) - 1 and additional_residuals is not None:
|
||||||
|
hidden_states = hidden_states + additional_residuals
|
||||||
|
|
||||||
|
if down_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + down_block_add_samples.pop(0)
|
||||||
|
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
|
||||||
|
if self.downsamplers is not None:
|
||||||
|
for downsampler in self.downsamplers:
|
||||||
|
hidden_states = downsampler(hidden_states, scale=lora_scale)
|
||||||
|
|
||||||
|
if down_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + down_block_add_samples.pop(
|
||||||
|
0
|
||||||
|
) # todo: add before or after
|
||||||
|
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
|
||||||
|
return hidden_states, output_states
|
||||||
|
|
||||||
|
|
||||||
|
def DownBlock2D_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
temb: Optional[torch.FloatTensor] = None,
|
||||||
|
scale: float = 1.0,
|
||||||
|
down_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||||
|
) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
|
||||||
|
output_states = ()
|
||||||
|
|
||||||
|
for resnet in self.resnets:
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
if is_torch_version(">=", "1.11.0"):
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet),
|
||||||
|
hidden_states,
|
||||||
|
temb,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet), hidden_states, temb
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||||
|
|
||||||
|
if down_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + down_block_add_samples.pop(0)
|
||||||
|
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
|
||||||
|
if self.downsamplers is not None:
|
||||||
|
for downsampler in self.downsamplers:
|
||||||
|
hidden_states = downsampler(hidden_states, scale=scale)
|
||||||
|
|
||||||
|
if down_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + down_block_add_samples.pop(
|
||||||
|
0
|
||||||
|
) # todo: add before or after
|
||||||
|
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
|
||||||
|
return hidden_states, output_states
|
||||||
|
|
||||||
|
|
||||||
|
def CrossAttnUpBlock2D_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||||
|
temb: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
||||||
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
upsample_size: Optional[int] = None,
|
||||||
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
|
return_res_samples: Optional[bool] = False,
|
||||||
|
up_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
lora_scale = (
|
||||||
|
cross_attention_kwargs.get("scale", 1.0)
|
||||||
|
if cross_attention_kwargs is not None
|
||||||
|
else 1.0
|
||||||
|
)
|
||||||
|
is_freeu_enabled = (
|
||||||
|
getattr(self, "s1", None)
|
||||||
|
and getattr(self, "s2", None)
|
||||||
|
and getattr(self, "b1", None)
|
||||||
|
and getattr(self, "b2", None)
|
||||||
|
)
|
||||||
|
if return_res_samples:
|
||||||
|
output_states = ()
|
||||||
|
|
||||||
|
for resnet, attn in zip(self.resnets, self.attentions):
|
||||||
|
# pop res hidden states
|
||||||
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
|
||||||
|
# FreeU: Only operate on the first two stages
|
||||||
|
if is_freeu_enabled:
|
||||||
|
hidden_states, res_hidden_states = apply_freeu(
|
||||||
|
self.resolution_idx,
|
||||||
|
hidden_states,
|
||||||
|
res_hidden_states,
|
||||||
|
s1=self.s1,
|
||||||
|
s2=self.s2,
|
||||||
|
b1=self.b1,
|
||||||
|
b2=self.b2,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
|
def create_custom_forward(module, return_dict=None):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
if return_dict is not None:
|
||||||
|
return module(*inputs, return_dict=return_dict)
|
||||||
|
else:
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
ckpt_kwargs: Dict[str, Any] = (
|
||||||
|
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||||
|
)
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet),
|
||||||
|
hidden_states,
|
||||||
|
temb,
|
||||||
|
**ckpt_kwargs,
|
||||||
|
)
|
||||||
|
hidden_states = attn(
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
else:
|
||||||
|
hidden_states = resnet(hidden_states, temb, scale=lora_scale)
|
||||||
|
hidden_states = attn(
|
||||||
|
hidden_states,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
|
if return_res_samples:
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
if up_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + up_block_add_samples.pop(0)
|
||||||
|
|
||||||
|
if self.upsamplers is not None:
|
||||||
|
for upsampler in self.upsamplers:
|
||||||
|
hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
|
||||||
|
if return_res_samples:
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
if up_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + up_block_add_samples.pop(0)
|
||||||
|
|
||||||
|
if return_res_samples:
|
||||||
|
return hidden_states, output_states
|
||||||
|
else:
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
def UpBlock2D_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
||||||
|
temb: Optional[torch.FloatTensor] = None,
|
||||||
|
upsample_size: Optional[int] = None,
|
||||||
|
scale: float = 1.0,
|
||||||
|
return_res_samples: Optional[bool] = False,
|
||||||
|
up_block_add_samples: Optional[torch.FloatTensor] = None,
|
||||||
|
) -> torch.FloatTensor:
|
||||||
|
is_freeu_enabled = (
|
||||||
|
getattr(self, "s1", None)
|
||||||
|
and getattr(self, "s2", None)
|
||||||
|
and getattr(self, "b1", None)
|
||||||
|
and getattr(self, "b2", None)
|
||||||
|
)
|
||||||
|
if return_res_samples:
|
||||||
|
output_states = ()
|
||||||
|
|
||||||
|
for resnet in self.resnets:
|
||||||
|
# pop res hidden states
|
||||||
|
res_hidden_states = res_hidden_states_tuple[-1]
|
||||||
|
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||||
|
|
||||||
|
# FreeU: Only operate on the first two stages
|
||||||
|
if is_freeu_enabled:
|
||||||
|
hidden_states, res_hidden_states = apply_freeu(
|
||||||
|
self.resolution_idx,
|
||||||
|
hidden_states,
|
||||||
|
res_hidden_states,
|
||||||
|
s1=self.s1,
|
||||||
|
s2=self.s2,
|
||||||
|
b1=self.b1,
|
||||||
|
b2=self.b2,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||||
|
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
|
||||||
|
def create_custom_forward(module):
|
||||||
|
def custom_forward(*inputs):
|
||||||
|
return module(*inputs)
|
||||||
|
|
||||||
|
return custom_forward
|
||||||
|
|
||||||
|
if is_torch_version(">=", "1.11.0"):
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet),
|
||||||
|
hidden_states,
|
||||||
|
temb,
|
||||||
|
use_reentrant=False,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||||
|
create_custom_forward(resnet), hidden_states, temb
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states = resnet(hidden_states, temb, scale=scale)
|
||||||
|
|
||||||
|
if return_res_samples:
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
if up_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + up_block_add_samples.pop(
|
||||||
|
0
|
||||||
|
) # todo: add before or after
|
||||||
|
|
||||||
|
if self.upsamplers is not None:
|
||||||
|
for upsampler in self.upsamplers:
|
||||||
|
hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
|
||||||
|
|
||||||
|
if return_res_samples:
|
||||||
|
output_states = output_states + (hidden_states,)
|
||||||
|
if up_block_add_samples is not None:
|
||||||
|
hidden_states = hidden_states + up_block_add_samples.pop(
|
||||||
|
0
|
||||||
|
) # todo: add before or after
|
||||||
|
|
||||||
|
if return_res_samples:
|
||||||
|
return hidden_states, output_states
|
||||||
|
else:
|
||||||
|
return hidden_states
|
402
iopaint/model/power_paint/v2/unet_2d_condition.py
Normal file
402
iopaint/model/power_paint/v2/unet_2d_condition.py
Normal file
@ -0,0 +1,402 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from typing import Any, Dict, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.utils.checkpoint
|
||||||
|
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
||||||
|
from diffusers.utils import (
|
||||||
|
USE_PEFT_BACKEND,
|
||||||
|
deprecate,
|
||||||
|
logging,
|
||||||
|
scale_lora_layers,
|
||||||
|
unscale_lora_layers,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
def UNet2DConditionModel_forward(
|
||||||
|
self,
|
||||||
|
sample: torch.FloatTensor,
|
||||||
|
timestep: Union[torch.Tensor, float, int],
|
||||||
|
encoder_hidden_states: torch.Tensor,
|
||||||
|
class_labels: Optional[torch.Tensor] = None,
|
||||||
|
timestep_cond: Optional[torch.Tensor] = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||||
|
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
|
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
mid_block_additional_residual: Optional[torch.Tensor] = None,
|
||||||
|
down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
return_dict: bool = True,
|
||||||
|
down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||||
|
r"""
|
||||||
|
The [`UNet2DConditionModel`] forward method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sample (`torch.FloatTensor`):
|
||||||
|
The noisy input tensor with the following shape `(batch, channel, height, width)`.
|
||||||
|
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
|
||||||
|
encoder_hidden_states (`torch.FloatTensor`):
|
||||||
|
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
|
||||||
|
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
|
||||||
|
timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
|
||||||
|
through the `self.time_embedding` layer to obtain the timestep embeddings.
|
||||||
|
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
|
||||||
|
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
||||||
|
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
||||||
|
negative values to the attention scores corresponding to "discard" tokens.
|
||||||
|
cross_attention_kwargs (`dict`, *optional*):
|
||||||
|
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||||
|
`self.processor` in
|
||||||
|
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||||
|
added_cond_kwargs: (`dict`, *optional*):
|
||||||
|
A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
|
||||||
|
are passed along to the UNet blocks.
|
||||||
|
down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
|
||||||
|
A tuple of tensors that if specified are added to the residuals of down unet blocks.
|
||||||
|
mid_block_additional_residual: (`torch.Tensor`, *optional*):
|
||||||
|
A tensor that if specified is added to the residual of the middle unet block.
|
||||||
|
encoder_attention_mask (`torch.Tensor`):
|
||||||
|
A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
|
||||||
|
`True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
|
||||||
|
which adds large negative values to the attention scores corresponding to "discard" tokens.
|
||||||
|
return_dict (`bool`, *optional*, defaults to `True`):
|
||||||
|
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
|
||||||
|
tuple.
|
||||||
|
cross_attention_kwargs (`dict`, *optional*):
|
||||||
|
A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
|
||||||
|
added_cond_kwargs: (`dict`, *optional*):
|
||||||
|
A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that
|
||||||
|
are passed along to the UNet blocks.
|
||||||
|
down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
||||||
|
additional residuals to be added to UNet long skip connections from down blocks to up blocks for
|
||||||
|
example from ControlNet side model(s)
|
||||||
|
mid_block_additional_residual (`torch.Tensor`, *optional*):
|
||||||
|
additional residual to be added to UNet mid block output, for example from ControlNet side model
|
||||||
|
down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
|
||||||
|
additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
|
||||||
|
If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
|
||||||
|
a `tuple` is returned where the first element is the sample tensor.
|
||||||
|
"""
|
||||||
|
# By default samples have to be AT least a multiple of the overall upsampling factor.
|
||||||
|
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
|
||||||
|
# However, the upsampling interpolation output size can be forced to fit any upsampling size
|
||||||
|
# on the fly if necessary.
|
||||||
|
default_overall_up_factor = 2**self.num_upsamplers
|
||||||
|
|
||||||
|
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
||||||
|
forward_upsample_size = False
|
||||||
|
upsample_size = None
|
||||||
|
|
||||||
|
for dim in sample.shape[-2:]:
|
||||||
|
if dim % default_overall_up_factor != 0:
|
||||||
|
# Forward upsample size to force interpolation output size.
|
||||||
|
forward_upsample_size = True
|
||||||
|
break
|
||||||
|
|
||||||
|
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension
|
||||||
|
# expects mask of shape:
|
||||||
|
# [batch, key_tokens]
|
||||||
|
# adds singleton query_tokens dimension:
|
||||||
|
# [batch, 1, key_tokens]
|
||||||
|
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
|
||||||
|
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
|
||||||
|
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
|
||||||
|
if attention_mask is not None:
|
||||||
|
# assume that mask is expressed as:
|
||||||
|
# (1 = keep, 0 = discard)
|
||||||
|
# convert mask into a bias that can be added to attention scores:
|
||||||
|
# (keep = +0, discard = -10000.0)
|
||||||
|
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||||
|
attention_mask = attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||||
|
if encoder_attention_mask is not None:
|
||||||
|
encoder_attention_mask = (
|
||||||
|
1 - encoder_attention_mask.to(sample.dtype)
|
||||||
|
) * -10000.0
|
||||||
|
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
||||||
|
|
||||||
|
# 0. center input if necessary
|
||||||
|
if self.config.center_input_sample:
|
||||||
|
sample = 2 * sample - 1.0
|
||||||
|
|
||||||
|
# 1. time
|
||||||
|
t_emb = self.get_time_embed(sample=sample, timestep=timestep)
|
||||||
|
emb = self.time_embedding(t_emb, timestep_cond)
|
||||||
|
aug_emb = None
|
||||||
|
|
||||||
|
class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
|
||||||
|
if class_emb is not None:
|
||||||
|
if self.config.class_embeddings_concat:
|
||||||
|
emb = torch.cat([emb, class_emb], dim=-1)
|
||||||
|
else:
|
||||||
|
emb = emb + class_emb
|
||||||
|
|
||||||
|
aug_emb = self.get_aug_embed(
|
||||||
|
emb=emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
|
)
|
||||||
|
if self.config.addition_embed_type == "image_hint":
|
||||||
|
aug_emb, hint = aug_emb
|
||||||
|
sample = torch.cat([sample, hint], dim=1)
|
||||||
|
|
||||||
|
emb = emb + aug_emb if aug_emb is not None else emb
|
||||||
|
|
||||||
|
if self.time_embed_act is not None:
|
||||||
|
emb = self.time_embed_act(emb)
|
||||||
|
|
||||||
|
encoder_hidden_states = self.process_encoder_hidden_states(
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
added_cond_kwargs=added_cond_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. pre-process
|
||||||
|
sample = self.conv_in(sample)
|
||||||
|
|
||||||
|
# 2.5 GLIGEN position net
|
||||||
|
if (
|
||||||
|
cross_attention_kwargs is not None
|
||||||
|
and cross_attention_kwargs.get("gligen", None) is not None
|
||||||
|
):
|
||||||
|
cross_attention_kwargs = cross_attention_kwargs.copy()
|
||||||
|
gligen_args = cross_attention_kwargs.pop("gligen")
|
||||||
|
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
|
||||||
|
|
||||||
|
# 3. down
|
||||||
|
lora_scale = (
|
||||||
|
cross_attention_kwargs.get("scale", 1.0)
|
||||||
|
if cross_attention_kwargs is not None
|
||||||
|
else 1.0
|
||||||
|
)
|
||||||
|
if USE_PEFT_BACKEND:
|
||||||
|
# weight the lora layers by setting `lora_scale` for each PEFT layer
|
||||||
|
scale_lora_layers(self, lora_scale)
|
||||||
|
|
||||||
|
is_controlnet = (
|
||||||
|
mid_block_additional_residual is not None
|
||||||
|
and down_block_additional_residuals is not None
|
||||||
|
)
|
||||||
|
# using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
|
||||||
|
is_adapter = down_intrablock_additional_residuals is not None
|
||||||
|
# maintain backward compatibility for legacy usage, where
|
||||||
|
# T2I-Adapter and ControlNet both use down_block_additional_residuals arg
|
||||||
|
# but can only use one or the other
|
||||||
|
is_brushnet = (
|
||||||
|
down_block_add_samples is not None
|
||||||
|
and mid_block_add_sample is not None
|
||||||
|
and up_block_add_samples is not None
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
not is_adapter
|
||||||
|
and mid_block_additional_residual is None
|
||||||
|
and down_block_additional_residuals is not None
|
||||||
|
):
|
||||||
|
deprecate(
|
||||||
|
"T2I should not use down_block_additional_residuals",
|
||||||
|
"1.3.0",
|
||||||
|
"Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
|
||||||
|
and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
|
||||||
|
for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ",
|
||||||
|
standard_warn=False,
|
||||||
|
)
|
||||||
|
down_intrablock_additional_residuals = down_block_additional_residuals
|
||||||
|
is_adapter = True
|
||||||
|
|
||||||
|
down_block_res_samples = (sample,)
|
||||||
|
|
||||||
|
if is_brushnet:
|
||||||
|
sample = sample + down_block_add_samples.pop(0)
|
||||||
|
|
||||||
|
for downsample_block in self.down_blocks:
|
||||||
|
if (
|
||||||
|
hasattr(downsample_block, "has_cross_attention")
|
||||||
|
and downsample_block.has_cross_attention
|
||||||
|
):
|
||||||
|
# For t2i-adapter CrossAttnDownBlock2D
|
||||||
|
additional_residuals = {}
|
||||||
|
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
||||||
|
additional_residuals["additional_residuals"] = (
|
||||||
|
down_intrablock_additional_residuals.pop(0)
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_brushnet and len(down_block_add_samples) > 0:
|
||||||
|
additional_residuals["down_block_add_samples"] = [
|
||||||
|
down_block_add_samples.pop(0)
|
||||||
|
for _ in range(
|
||||||
|
len(downsample_block.resnets)
|
||||||
|
+ (downsample_block.downsamplers != None)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
sample, res_samples = downsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
**additional_residuals,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
additional_residuals = {}
|
||||||
|
if is_brushnet and len(down_block_add_samples) > 0:
|
||||||
|
additional_residuals["down_block_add_samples"] = [
|
||||||
|
down_block_add_samples.pop(0)
|
||||||
|
for _ in range(
|
||||||
|
len(downsample_block.resnets)
|
||||||
|
+ (downsample_block.downsamplers != None)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
sample, res_samples = downsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
scale=lora_scale,
|
||||||
|
**additional_residuals,
|
||||||
|
)
|
||||||
|
if is_adapter and len(down_intrablock_additional_residuals) > 0:
|
||||||
|
sample += down_intrablock_additional_residuals.pop(0)
|
||||||
|
|
||||||
|
down_block_res_samples += res_samples
|
||||||
|
|
||||||
|
if is_controlnet:
|
||||||
|
new_down_block_res_samples = ()
|
||||||
|
|
||||||
|
for down_block_res_sample, down_block_additional_residual in zip(
|
||||||
|
down_block_res_samples, down_block_additional_residuals
|
||||||
|
):
|
||||||
|
down_block_res_sample = (
|
||||||
|
down_block_res_sample + down_block_additional_residual
|
||||||
|
)
|
||||||
|
new_down_block_res_samples = new_down_block_res_samples + (
|
||||||
|
down_block_res_sample,
|
||||||
|
)
|
||||||
|
|
||||||
|
down_block_res_samples = new_down_block_res_samples
|
||||||
|
|
||||||
|
# 4. mid
|
||||||
|
if self.mid_block is not None:
|
||||||
|
if (
|
||||||
|
hasattr(self.mid_block, "has_cross_attention")
|
||||||
|
and self.mid_block.has_cross_attention
|
||||||
|
):
|
||||||
|
sample = self.mid_block(
|
||||||
|
sample,
|
||||||
|
emb,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample = self.mid_block(sample, emb)
|
||||||
|
|
||||||
|
# To support T2I-Adapter-XL
|
||||||
|
if (
|
||||||
|
is_adapter
|
||||||
|
and len(down_intrablock_additional_residuals) > 0
|
||||||
|
and sample.shape == down_intrablock_additional_residuals[0].shape
|
||||||
|
):
|
||||||
|
sample += down_intrablock_additional_residuals.pop(0)
|
||||||
|
|
||||||
|
if is_controlnet:
|
||||||
|
sample = sample + mid_block_additional_residual
|
||||||
|
|
||||||
|
if is_brushnet:
|
||||||
|
sample = sample + mid_block_add_sample
|
||||||
|
|
||||||
|
# 5. up
|
||||||
|
for i, upsample_block in enumerate(self.up_blocks):
|
||||||
|
is_final_block = i == len(self.up_blocks) - 1
|
||||||
|
|
||||||
|
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
||||||
|
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
||||||
|
|
||||||
|
# if we have not reached the final block and need to forward the
|
||||||
|
# upsample size, we do it here
|
||||||
|
if not is_final_block and forward_upsample_size:
|
||||||
|
upsample_size = down_block_res_samples[-1].shape[2:]
|
||||||
|
|
||||||
|
if (
|
||||||
|
hasattr(upsample_block, "has_cross_attention")
|
||||||
|
and upsample_block.has_cross_attention
|
||||||
|
):
|
||||||
|
additional_residuals = {}
|
||||||
|
if is_brushnet and len(up_block_add_samples) > 0:
|
||||||
|
additional_residuals["up_block_add_samples"] = [
|
||||||
|
up_block_add_samples.pop(0)
|
||||||
|
for _ in range(
|
||||||
|
len(upsample_block.resnets)
|
||||||
|
+ (upsample_block.upsamplers != None)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
sample = upsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
res_hidden_states_tuple=res_samples,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
|
upsample_size=upsample_size,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
|
**additional_residuals,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
additional_residuals = {}
|
||||||
|
if is_brushnet and len(up_block_add_samples) > 0:
|
||||||
|
additional_residuals["up_block_add_samples"] = [
|
||||||
|
up_block_add_samples.pop(0)
|
||||||
|
for _ in range(
|
||||||
|
len(upsample_block.resnets)
|
||||||
|
+ (upsample_block.upsamplers != None)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
sample = upsample_block(
|
||||||
|
hidden_states=sample,
|
||||||
|
temb=emb,
|
||||||
|
res_hidden_states_tuple=res_samples,
|
||||||
|
upsample_size=upsample_size,
|
||||||
|
scale=lora_scale,
|
||||||
|
**additional_residuals,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. post-process
|
||||||
|
if self.conv_norm_out:
|
||||||
|
sample = self.conv_norm_out(sample)
|
||||||
|
sample = self.conv_act(sample)
|
||||||
|
sample = self.conv_out(sample)
|
||||||
|
|
||||||
|
if USE_PEFT_BACKEND:
|
||||||
|
# remove `lora_scale` from each PEFT layer
|
||||||
|
unscale_lora_layers(self, lora_scale)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return (sample,)
|
||||||
|
|
||||||
|
return UNet2DConditionOutput(sample=sample)
|
@ -7,6 +7,8 @@ import numpy as np
|
|||||||
from iopaint.download import scan_models
|
from iopaint.download import scan_models
|
||||||
from iopaint.helper import switch_mps_device
|
from iopaint.helper import switch_mps_device
|
||||||
from iopaint.model import models, ControlNet, SD, SDXL
|
from iopaint.model import models, ControlNet, SD, SDXL
|
||||||
|
from iopaint.model.brushnet.brushnet_wrapper import BrushNetWrapper
|
||||||
|
from iopaint.model.power_paint.power_paint_v2 import PowerPaintV2
|
||||||
from iopaint.model.utils import torch_gc, is_local_files_only
|
from iopaint.model.utils import torch_gc, is_local_files_only
|
||||||
from iopaint.schema import InpaintRequest, ModelInfo, ModelType
|
from iopaint.schema import InpaintRequest, ModelInfo, ModelType
|
||||||
|
|
||||||
@ -28,6 +30,12 @@ class ModelManager:
|
|||||||
):
|
):
|
||||||
controlnet_method = self.available_models[name].controlnets[0]
|
controlnet_method = self.available_models[name].controlnets[0]
|
||||||
self.controlnet_method = controlnet_method
|
self.controlnet_method = controlnet_method
|
||||||
|
|
||||||
|
self.enable_brushnet = kwargs.get("enable_brushnet", False)
|
||||||
|
self.brushnet_method = kwargs.get("brushnet_method", None)
|
||||||
|
|
||||||
|
self.enable_powerpaint_v2 = kwargs.get("enable_powerpaint_v2", False)
|
||||||
|
|
||||||
self.model = self.init_model(name, device, **kwargs)
|
self.model = self.init_model(name, device, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -47,13 +55,22 @@ class ModelManager:
|
|||||||
"model_info": model_info,
|
"model_info": model_info,
|
||||||
"enable_controlnet": self.enable_controlnet,
|
"enable_controlnet": self.enable_controlnet,
|
||||||
"controlnet_method": self.controlnet_method,
|
"controlnet_method": self.controlnet_method,
|
||||||
|
"enable_brushnet": self.enable_brushnet,
|
||||||
|
"brushnet_method": self.brushnet_method,
|
||||||
}
|
}
|
||||||
|
|
||||||
if model_info.support_controlnet and self.enable_controlnet:
|
if model_info.support_controlnet and self.enable_controlnet:
|
||||||
return ControlNet(device, **kwargs)
|
return ControlNet(device, **kwargs)
|
||||||
elif model_info.name in models:
|
|
||||||
|
if model_info.support_brushnet and self.enable_brushnet:
|
||||||
|
return BrushNetWrapper(device, **kwargs)
|
||||||
|
|
||||||
|
if model_info.support_powerpaint_v2 and self.enable_powerpaint_v2:
|
||||||
|
return PowerPaintV2(device, **kwargs)
|
||||||
|
|
||||||
|
if model_info.name in models:
|
||||||
return models[name](device, **kwargs)
|
return models[name](device, **kwargs)
|
||||||
else:
|
|
||||||
if model_info.model_type in [
|
if model_info.model_type in [
|
||||||
ModelType.DIFFUSERS_SD_INPAINT,
|
ModelType.DIFFUSERS_SD_INPAINT,
|
||||||
ModelType.DIFFUSERS_SD,
|
ModelType.DIFFUSERS_SD,
|
||||||
@ -80,8 +97,12 @@ class ModelManager:
|
|||||||
Returns:
|
Returns:
|
||||||
BGR image
|
BGR image
|
||||||
"""
|
"""
|
||||||
|
if config.enable_controlnet:
|
||||||
self.switch_controlnet_method(config)
|
self.switch_controlnet_method(config)
|
||||||
self.enable_disable_freeu(config)
|
if config.enable_brushnet:
|
||||||
|
self.switch_brushnet_method(config)
|
||||||
|
|
||||||
|
self.enable_disable_powerpaint_v2(config)
|
||||||
self.enable_disable_lcm_lora(config)
|
self.enable_disable_lcm_lora(config)
|
||||||
return self.model(image, mask, config).astype(np.uint8)
|
return self.model(image, mask, config).astype(np.uint8)
|
||||||
|
|
||||||
@ -121,6 +142,46 @@ class ModelManager:
|
|||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
def switch_brushnet_method(self, config):
|
||||||
|
if not self.available_models[self.name].support_brushnet:
|
||||||
|
return
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.enable_brushnet
|
||||||
|
and config.brushnet_method
|
||||||
|
and self.brushnet_method != config.brushnet_method
|
||||||
|
):
|
||||||
|
old_brushnet_method = self.brushnet_method
|
||||||
|
self.brushnet_method = config.brushnet_method
|
||||||
|
self.model.switch_brushnet_method(config.brushnet_method)
|
||||||
|
logger.info(
|
||||||
|
f"Switch Brushnet method from {old_brushnet_method} to {config.brushnet_method}"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif self.enable_brushnet != config.enable_brushnet:
|
||||||
|
self.enable_brushnet = config.enable_brushnet
|
||||||
|
self.brushnet_method = config.brushnet_method
|
||||||
|
|
||||||
|
pipe_components = {
|
||||||
|
"vae": self.model.model.vae,
|
||||||
|
"text_encoder": self.model.model.text_encoder,
|
||||||
|
"unet": self.model.model.unet,
|
||||||
|
}
|
||||||
|
if hasattr(self.model.model, "text_encoder_2"):
|
||||||
|
pipe_components["text_encoder_2"] = self.model.model.text_encoder_2
|
||||||
|
|
||||||
|
self.model = self.init_model(
|
||||||
|
self.name,
|
||||||
|
switch_mps_device(self.name, self.device),
|
||||||
|
pipe_components=pipe_components,
|
||||||
|
**self.kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not config.enable_brushnet:
|
||||||
|
logger.info("BrushNet Disabled")
|
||||||
|
else:
|
||||||
|
logger.info("BrushNet Enabled")
|
||||||
|
|
||||||
def switch_controlnet_method(self, config):
|
def switch_controlnet_method(self, config):
|
||||||
if not self.available_models[self.name].support_controlnet:
|
if not self.available_models[self.name].support_controlnet:
|
||||||
return
|
return
|
||||||
@ -155,25 +216,28 @@ class ModelManager:
|
|||||||
**self.kwargs,
|
**self.kwargs,
|
||||||
)
|
)
|
||||||
if not config.enable_controlnet:
|
if not config.enable_controlnet:
|
||||||
logger.info(f"Disable controlnet")
|
logger.info("Disable controlnet")
|
||||||
else:
|
else:
|
||||||
logger.info(f"Enable controlnet: {config.controlnet_method}")
|
logger.info(f"Enable controlnet: {config.controlnet_method}")
|
||||||
|
|
||||||
def enable_disable_freeu(self, config: InpaintRequest):
|
def enable_disable_powerpaint_v2(self, config: InpaintRequest):
|
||||||
if str(self.model.device) == "mps":
|
if not self.available_models[self.name].support_powerpaint_v2:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.available_models[self.name].support_freeu:
|
if self.enable_powerpaint_v2 != config.enable_powerpaint_v2:
|
||||||
if config.sd_freeu:
|
self.enable_powerpaint_v2 = config.enable_powerpaint_v2
|
||||||
freeu_config = config.sd_freeu_config
|
pipe_components = {"vae": self.model.model.vae}
|
||||||
self.model.model.enable_freeu(
|
|
||||||
s1=freeu_config.s1,
|
self.model = self.init_model(
|
||||||
s2=freeu_config.s2,
|
self.name,
|
||||||
b1=freeu_config.b1,
|
switch_mps_device(self.name, self.device),
|
||||||
b2=freeu_config.b2,
|
pipe_components=pipe_components,
|
||||||
|
**self.kwargs,
|
||||||
)
|
)
|
||||||
|
if config.enable_powerpaint_v2:
|
||||||
|
logger.info("Enable PowerPaintV2")
|
||||||
else:
|
else:
|
||||||
self.model.model.disable_freeu()
|
logger.info("Disable PowerPaintV2")
|
||||||
|
|
||||||
def enable_disable_lcm_lora(self, config: InpaintRequest):
|
def enable_disable_lcm_lora(self, config: InpaintRequest):
|
||||||
if self.available_models[self.name].support_lcm_lora:
|
if self.available_models[self.name].support_lcm_lora:
|
||||||
|
@ -3,6 +3,8 @@ from enum import Enum
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional, Literal, List
|
from typing import Optional, Literal, List
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
from iopaint.const import (
|
from iopaint.const import (
|
||||||
INSTRUCT_PIX2PIX_NAME,
|
INSTRUCT_PIX2PIX_NAME,
|
||||||
KANDINSKY22_NAME,
|
KANDINSKY22_NAME,
|
||||||
@ -11,9 +13,9 @@ from iopaint.const import (
|
|||||||
SDXL_CONTROLNET_CHOICES,
|
SDXL_CONTROLNET_CHOICES,
|
||||||
SD2_CONTROLNET_CHOICES,
|
SD2_CONTROLNET_CHOICES,
|
||||||
SD_CONTROLNET_CHOICES,
|
SD_CONTROLNET_CHOICES,
|
||||||
|
SD_BRUSHNET_CHOICES,
|
||||||
)
|
)
|
||||||
from loguru import logger
|
from pydantic import BaseModel, Field, computed_field, model_validator
|
||||||
from pydantic import BaseModel, Field, field_validator, computed_field
|
|
||||||
|
|
||||||
|
|
||||||
class ModelType(str, Enum):
|
class ModelType(str, Enum):
|
||||||
@ -63,6 +65,13 @@ class ModelInfo(BaseModel):
|
|||||||
return SD_CONTROLNET_CHOICES
|
return SD_CONTROLNET_CHOICES
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
@computed_field
|
||||||
|
@property
|
||||||
|
def brushnets(self) -> List[str]:
|
||||||
|
if self.model_type in [ModelType.DIFFUSERS_SD]:
|
||||||
|
return SD_BRUSHNET_CHOICES
|
||||||
|
return []
|
||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
@property
|
@property
|
||||||
def support_strength(self) -> bool:
|
def support_strength(self) -> bool:
|
||||||
@ -105,13 +114,21 @@ class ModelInfo(BaseModel):
|
|||||||
|
|
||||||
@computed_field
|
@computed_field
|
||||||
@property
|
@property
|
||||||
def support_freeu(self) -> bool:
|
def support_brushnet(self) -> bool:
|
||||||
return self.model_type in [
|
return self.model_type in [
|
||||||
ModelType.DIFFUSERS_SD,
|
ModelType.DIFFUSERS_SD,
|
||||||
ModelType.DIFFUSERS_SDXL,
|
]
|
||||||
ModelType.DIFFUSERS_SD_INPAINT,
|
|
||||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
@computed_field
|
||||||
] or self.name in [INSTRUCT_PIX2PIX_NAME]
|
@property
|
||||||
|
def support_powerpaint_v2(self) -> bool:
|
||||||
|
return (
|
||||||
|
self.model_type
|
||||||
|
in [
|
||||||
|
ModelType.DIFFUSERS_SD,
|
||||||
|
]
|
||||||
|
and self.name != POWERPAINT_NAME
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Choices(str, Enum):
|
class Choices(str, Enum):
|
||||||
@ -202,15 +219,9 @@ class SDSampler(str, Enum):
|
|||||||
lcm = "LCM"
|
lcm = "LCM"
|
||||||
|
|
||||||
|
|
||||||
class FREEUConfig(BaseModel):
|
class PowerPaintTask(Choices):
|
||||||
s1: float = 0.9
|
|
||||||
s2: float = 0.2
|
|
||||||
b1: float = 1.2
|
|
||||||
b2: float = 1.4
|
|
||||||
|
|
||||||
|
|
||||||
class PowerPaintTask(str, Enum):
|
|
||||||
text_guided = "text-guided"
|
text_guided = "text-guided"
|
||||||
|
context_aware = "context-aware"
|
||||||
shape_guided = "shape-guided"
|
shape_guided = "shape-guided"
|
||||||
object_remove = "object-remove"
|
object_remove = "object-remove"
|
||||||
outpainting = "outpainting"
|
outpainting = "outpainting"
|
||||||
@ -328,12 +339,6 @@ class InpaintRequest(BaseModel):
|
|||||||
sd_outpainting_softness: float = Field(20.0)
|
sd_outpainting_softness: float = Field(20.0)
|
||||||
sd_outpainting_space: float = Field(20.0)
|
sd_outpainting_space: float = Field(20.0)
|
||||||
|
|
||||||
sd_freeu: bool = Field(
|
|
||||||
False,
|
|
||||||
description="Enable freeu mode. https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu",
|
|
||||||
)
|
|
||||||
sd_freeu_config: FREEUConfig = FREEUConfig()
|
|
||||||
|
|
||||||
sd_lcm_lora: bool = Field(
|
sd_lcm_lora: bool = Field(
|
||||||
False,
|
False,
|
||||||
description="Enable lcm-lora mode. https://huggingface.co/docs/diffusers/main/en/using-diffusers/inference_with_lcm#texttoimage",
|
description="Enable lcm-lora mode. https://huggingface.co/docs/diffusers/main/en/using-diffusers/inference_with_lcm#texttoimage",
|
||||||
@ -369,7 +374,15 @@ class InpaintRequest(BaseModel):
|
|||||||
"lllyasviel/control_v11p_sd15_canny", description="Controlnet method"
|
"lllyasviel/control_v11p_sd15_canny", description="Controlnet method"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# BrushNet
|
||||||
|
enable_brushnet: bool = Field(False, description="Enable brushnet")
|
||||||
|
brushnet_method: str = Field(SD_BRUSHNET_CHOICES[0], description="Brushnet method")
|
||||||
|
brushnet_conditioning_scale: float = Field(
|
||||||
|
1.0, description="brushnet conditioning scale", ge=0.0, le=1.0
|
||||||
|
)
|
||||||
|
|
||||||
# PowerPaint
|
# PowerPaint
|
||||||
|
enable_powerpaint_v2: bool = Field(False, description="Enable PowerPaint v2")
|
||||||
powerpaint_task: PowerPaintTask = Field(
|
powerpaint_task: PowerPaintTask = Field(
|
||||||
PowerPaintTask.text_guided, description="PowerPaint task"
|
PowerPaintTask.text_guided, description="PowerPaint task"
|
||||||
)
|
)
|
||||||
@ -380,31 +393,34 @@ class InpaintRequest(BaseModel):
|
|||||||
le=1.0,
|
le=1.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@field_validator("sd_seed")
|
@model_validator(mode="after")
|
||||||
@classmethod
|
def validate_field(cls, values: "InpaintRequest"):
|
||||||
def sd_seed_validator(cls, v: int) -> int:
|
if values.sd_seed == -1:
|
||||||
if v == -1:
|
values.sd_seed = random.randint(1, 99999999)
|
||||||
return random.randint(1, 99999999)
|
logger.info(f"Generate random seed: {values.sd_seed}")
|
||||||
return v
|
|
||||||
|
|
||||||
@field_validator("controlnet_conditioning_scale")
|
if values.use_extender and values.enable_controlnet:
|
||||||
@classmethod
|
logger.info("Extender is enabled, set controlnet_conditioning_scale=0")
|
||||||
def validate_field(cls, v: float, values):
|
values.controlnet_conditioning_scale = 0
|
||||||
use_extender = values.data["use_extender"]
|
|
||||||
enable_controlnet = values.data["enable_controlnet"]
|
|
||||||
if use_extender and enable_controlnet:
|
|
||||||
logger.info(f"Extender is enabled, set controlnet_conditioning_scale=0")
|
|
||||||
return 0
|
|
||||||
return v
|
|
||||||
|
|
||||||
@field_validator("sd_strength")
|
if values.use_extender:
|
||||||
@classmethod
|
logger.info("Extender is enabled, set sd_strength=1")
|
||||||
def validate_sd_strength(cls, v: float, values):
|
values.sd_strength = 1.0
|
||||||
use_extender = values.data["use_extender"]
|
|
||||||
if use_extender:
|
if values.enable_brushnet:
|
||||||
logger.info(f"Extender is enabled, set sd_strength=1")
|
logger.info("BrushNet is enabled, set enable_controlnet=False")
|
||||||
return 1.0
|
if values.enable_controlnet:
|
||||||
return v
|
values.enable_controlnet = False
|
||||||
|
if values.sd_lcm_lora:
|
||||||
|
logger.info("BrushNet is enabled, set sd_lcm_lora=False")
|
||||||
|
values.sd_lcm_lora = False
|
||||||
|
|
||||||
|
if values.enable_controlnet:
|
||||||
|
logger.info("ControlNet is enabled, set enable_brushnet=False")
|
||||||
|
if values.enable_brushnet:
|
||||||
|
values.enable_brushnet = False
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
class RunPluginRequest(BaseModel):
|
class RunPluginRequest(BaseModel):
|
||||||
|
110
iopaint/tests/test_brushnet.py
Normal file
110
iopaint/tests/test_brushnet.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
from iopaint.const import SD_BRUSHNET_CHOICES
|
||||||
|
from iopaint.tests.utils import check_device, get_config, assert_equal
|
||||||
|
|
||||||
|
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from iopaint.model_manager import ModelManager
|
||||||
|
from iopaint.schema import HDStrategy, SDSampler, PowerPaintTask
|
||||||
|
|
||||||
|
current_dir = Path(__file__).parent.absolute().resolve()
|
||||||
|
save_dir = current_dir / "result"
|
||||||
|
save_dir.mkdir(exist_ok=True, parents=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
||||||
|
@pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m_karras])
|
||||||
|
def test_runway_brushnet(device, sampler):
|
||||||
|
sd_steps = check_device(device)
|
||||||
|
model = ModelManager(
|
||||||
|
name="runwayml/stable-diffusion-v1-5",
|
||||||
|
device=torch.device(device),
|
||||||
|
disable_nsfw=True,
|
||||||
|
sd_cpu_textencoder=False,
|
||||||
|
)
|
||||||
|
cfg = get_config(
|
||||||
|
strategy=HDStrategy.ORIGINAL,
|
||||||
|
prompt="face of a fox, sitting on a bench",
|
||||||
|
sd_steps=sd_steps,
|
||||||
|
sd_guidance_scale=7.5,
|
||||||
|
enable_brushnet=True,
|
||||||
|
brushnet_method=SD_BRUSHNET_CHOICES[0],
|
||||||
|
)
|
||||||
|
cfg.sd_sampler = sampler
|
||||||
|
|
||||||
|
assert_equal(
|
||||||
|
model,
|
||||||
|
cfg,
|
||||||
|
f"brushnet_random_mask_{device}.png",
|
||||||
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||||
|
@pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m])
|
||||||
|
def test_runway_powerpaint_v2(device, sampler):
|
||||||
|
sd_steps = check_device(device)
|
||||||
|
model = ModelManager(
|
||||||
|
name="runwayml/stable-diffusion-v1-5",
|
||||||
|
device=torch.device(device),
|
||||||
|
disable_nsfw=True,
|
||||||
|
sd_cpu_textencoder=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
tasks = {
|
||||||
|
PowerPaintTask.text_guided: {
|
||||||
|
"prompt": "face of a fox, sitting on a bench",
|
||||||
|
"scale": 7.5,
|
||||||
|
},
|
||||||
|
PowerPaintTask.context_aware: {
|
||||||
|
"prompt": "face of a fox, sitting on a bench",
|
||||||
|
"scale": 7.5,
|
||||||
|
},
|
||||||
|
PowerPaintTask.shape_guided: {
|
||||||
|
"prompt": "face of a fox, sitting on a bench",
|
||||||
|
"scale": 7.5,
|
||||||
|
},
|
||||||
|
PowerPaintTask.object_remove: {
|
||||||
|
"prompt": "",
|
||||||
|
"scale": 12,
|
||||||
|
},
|
||||||
|
PowerPaintTask.outpainting: {
|
||||||
|
"prompt": "",
|
||||||
|
"scale": 7.5,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for task, data in tasks.items():
|
||||||
|
cfg = get_config(
|
||||||
|
strategy=HDStrategy.ORIGINAL,
|
||||||
|
prompt=data["prompt"],
|
||||||
|
negative_prompt="out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature",
|
||||||
|
sd_steps=sd_steps,
|
||||||
|
sd_guidance_scale=data["scale"],
|
||||||
|
enable_powerpaint_v2=True,
|
||||||
|
powerpaint_task=task,
|
||||||
|
sd_sampler=sampler,
|
||||||
|
sd_mask_blur=11,
|
||||||
|
sd_seed=42,
|
||||||
|
# sd_keep_unmasked_area=False
|
||||||
|
)
|
||||||
|
if task == PowerPaintTask.outpainting:
|
||||||
|
cfg.use_extender = True
|
||||||
|
cfg.extender_x = -128
|
||||||
|
cfg.extender_y = -128
|
||||||
|
cfg.extender_width = 768
|
||||||
|
cfg.extender_height = 768
|
||||||
|
|
||||||
|
assert_equal(
|
||||||
|
model,
|
||||||
|
cfg,
|
||||||
|
f"powerpaint_v2_{device}_{task}.png",
|
||||||
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
|
)
|
@ -10,7 +10,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from iopaint.model_manager import ModelManager
|
from iopaint.model_manager import ModelManager
|
||||||
from iopaint.schema import HDStrategy, SDSampler, FREEUConfig
|
from iopaint.schema import HDStrategy, SDSampler
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||||
@ -75,35 +75,6 @@ def test_runway_sd_lcm_lora_low_mem(device, sampler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
|
||||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
|
||||||
def test_runway_sd_freeu(device, sampler):
|
|
||||||
sd_steps = check_device(device)
|
|
||||||
model = ModelManager(
|
|
||||||
name="runwayml/stable-diffusion-inpainting",
|
|
||||||
device=torch.device(device),
|
|
||||||
disable_nsfw=True,
|
|
||||||
sd_cpu_textencoder=False,
|
|
||||||
low_mem=True,
|
|
||||||
)
|
|
||||||
cfg = get_config(
|
|
||||||
strategy=HDStrategy.ORIGINAL,
|
|
||||||
prompt="face of a fox, sitting on a bench",
|
|
||||||
sd_steps=sd_steps,
|
|
||||||
sd_guidance_scale=7.5,
|
|
||||||
sd_freeu=True,
|
|
||||||
sd_freeu_config=FREEUConfig(),
|
|
||||||
)
|
|
||||||
cfg.sd_sampler = sampler
|
|
||||||
|
|
||||||
assert_equal(
|
|
||||||
model,
|
|
||||||
cfg,
|
|
||||||
f"runway_sd_1_5_freeu_device_{device}_low_mem.png",
|
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
|
@ -3,18 +3,17 @@ import os
|
|||||||
from iopaint.tests.utils import current_dir, check_device
|
from iopaint.tests.utils import current_dir, check_device
|
||||||
|
|
||||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from iopaint.model_manager import ModelManager
|
from iopaint.model_manager import ModelManager
|
||||||
from iopaint.schema import HDStrategy, SDSampler
|
from iopaint.schema import SDSampler
|
||||||
from iopaint.tests.test_model import get_config, assert_equal
|
from iopaint.tests.test_model import get_config, assert_equal
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("name", ["runwayml/stable-diffusion-inpainting"])
|
@pytest.mark.parametrize("name", ["runwayml/stable-diffusion-inpainting"])
|
||||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"rect",
|
"rect",
|
||||||
[
|
[
|
||||||
@ -23,7 +22,7 @@ from iopaint.tests.test_model import get_config, assert_equal
|
|||||||
[128, 0, 512 - 128 + 100, 512],
|
[128, 0, 512 - 128 + 100, 512],
|
||||||
[-100, 0, 512 - 128 + 100, 512],
|
[-100, 0, 512 - 128 + 100, 512],
|
||||||
[0, 0, 512, 512 + 200],
|
[0, 0, 512, 512 + 200],
|
||||||
[0, 0, 512 + 200, 512],
|
[256, 0, 512 + 200, 512],
|
||||||
[-100, -100, 512 + 200, 512 + 200],
|
[-100, -100, 512 + 200, 512 + 200],
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -58,7 +57,7 @@ def test_outpainting(name, device, rect):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("name", ["kandinsky-community/kandinsky-2-2-decoder-inpaint"])
|
@pytest.mark.parametrize("name", ["kandinsky-community/kandinsky-2-2-decoder-inpaint"])
|
||||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"rect",
|
"rect",
|
||||||
[
|
[
|
||||||
@ -99,7 +98,7 @@ def test_kandinsky_outpainting(name, device, rect):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("name", ["Sanster/PowerPaint-V1-stable-diffusion-inpainting"])
|
@pytest.mark.parametrize("name", ["Sanster/PowerPaint-V1-stable-diffusion-inpainting"])
|
||||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"rect",
|
"rect",
|
||||||
[
|
[
|
||||||
@ -114,7 +113,7 @@ def test_powerpaint_outpainting(name, device, rect):
|
|||||||
device=torch.device(device),
|
device=torch.device(device),
|
||||||
disable_nsfw=True,
|
disable_nsfw=True,
|
||||||
sd_cpu_textencoder=False,
|
sd_cpu_textencoder=False,
|
||||||
low_mem=True
|
low_mem=True,
|
||||||
)
|
)
|
||||||
cfg = get_config(
|
cfg = get_config(
|
||||||
prompt="a dog sitting on a bench in the park",
|
prompt="a dog sitting on a bench in the park",
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import cv2
|
import cv2
|
||||||
import pytest
|
import pytest
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from iopaint.helper import encode_pil_to_base64
|
||||||
|
|
||||||
from iopaint.model_manager import ModelManager
|
from iopaint.model_manager import ModelManager
|
||||||
from iopaint.schema import HDStrategy
|
from iopaint.schema import HDStrategy
|
||||||
@ -34,7 +35,9 @@ def assert_equal(
|
|||||||
)
|
)
|
||||||
|
|
||||||
print(f"Input image shape: {img.shape}, example_image: {example_image.shape}")
|
print(f"Input image shape: {img.shape}, example_image: {example_image.shape}")
|
||||||
config.paint_by_example_example_image = Image.fromarray(example_image)
|
config.paint_by_example_example_image = encode_pil_to_base64(
|
||||||
|
Image.fromarray(example_image), 100, {}
|
||||||
|
).decode("utf-8")
|
||||||
res = model(img, mask, config)
|
res = model(img, mask, config)
|
||||||
cv2.imwrite(str(save_dir / save_name), res)
|
cv2.imwrite(str(save_dir / save_name), res)
|
||||||
|
|
||||||
|
@ -11,7 +11,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from iopaint.model_manager import ModelManager
|
from iopaint.model_manager import ModelManager
|
||||||
from iopaint.schema import HDStrategy, SDSampler, FREEUConfig
|
from iopaint.schema import HDStrategy, SDSampler
|
||||||
|
|
||||||
current_dir = Path(__file__).parent.absolute().resolve()
|
current_dir = Path(__file__).parent.absolute().resolve()
|
||||||
save_dir = current_dir / "result"
|
save_dir = current_dir / "result"
|
||||||
@ -90,35 +90,6 @@ def test_runway_sd_lcm_lora(device, sampler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
|
|
||||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
|
||||||
def test_runway_sd_freeu(device, sampler):
|
|
||||||
sd_steps = check_device(device)
|
|
||||||
model = ModelManager(
|
|
||||||
name="runwayml/stable-diffusion-inpainting",
|
|
||||||
device=torch.device(device),
|
|
||||||
disable_nsfw=True,
|
|
||||||
sd_cpu_textencoder=False,
|
|
||||||
)
|
|
||||||
cfg = get_config(
|
|
||||||
strategy=HDStrategy.ORIGINAL,
|
|
||||||
prompt="face of a fox, sitting on a bench",
|
|
||||||
sd_steps=sd_steps,
|
|
||||||
sd_guidance_scale=7.5,
|
|
||||||
sd_freeu=True,
|
|
||||||
sd_freeu_config=FREEUConfig(),
|
|
||||||
)
|
|
||||||
cfg.sd_sampler = sampler
|
|
||||||
|
|
||||||
assert_equal(
|
|
||||||
model,
|
|
||||||
cfg,
|
|
||||||
f"runway_sd_1_5_freeu_device_{device}.png",
|
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||||
|
@ -8,7 +8,7 @@ import pytest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from iopaint.model_manager import ModelManager
|
from iopaint.model_manager import ModelManager
|
||||||
from iopaint.schema import HDStrategy, SDSampler, FREEUConfig
|
from iopaint.schema import HDStrategy, SDSampler
|
||||||
from iopaint.tests.test_model import get_config, assert_equal
|
from iopaint.tests.test_model import get_config, assert_equal
|
||||||
|
|
||||||
|
|
||||||
@ -76,60 +76,6 @@ def test_sdxl_cpu_text_encoder(device, strategy, sampler):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
|
||||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
|
||||||
def test_sdxl_lcm_lora_and_freeu(device, strategy, sampler):
|
|
||||||
sd_steps = check_device(device)
|
|
||||||
|
|
||||||
model = ModelManager(
|
|
||||||
name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
|
|
||||||
device=torch.device(device),
|
|
||||||
disable_nsfw=True,
|
|
||||||
sd_cpu_textencoder=False,
|
|
||||||
)
|
|
||||||
cfg = get_config(
|
|
||||||
strategy=strategy,
|
|
||||||
prompt="face of a fox, sitting on a bench",
|
|
||||||
sd_steps=sd_steps,
|
|
||||||
sd_strength=1.0,
|
|
||||||
sd_guidance_scale=2.0,
|
|
||||||
sd_lcm_lora=True,
|
|
||||||
)
|
|
||||||
cfg.sd_sampler = sampler
|
|
||||||
|
|
||||||
name = f"device_{device}_{sampler}"
|
|
||||||
|
|
||||||
assert_equal(
|
|
||||||
model,
|
|
||||||
cfg,
|
|
||||||
f"sdxl_{name}_lcm_lora.png",
|
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
|
||||||
fx=2,
|
|
||||||
fy=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg = get_config(
|
|
||||||
strategy=strategy,
|
|
||||||
prompt="face of a fox, sitting on a bench",
|
|
||||||
sd_steps=sd_steps,
|
|
||||||
sd_guidance_scale=7.5,
|
|
||||||
sd_freeu=True,
|
|
||||||
sd_freeu_config=FREEUConfig(),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert_equal(
|
|
||||||
model,
|
|
||||||
cfg,
|
|
||||||
f"sdxl_{name}_freeu_device_{device}.png",
|
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
|
||||||
fx=2,
|
|
||||||
fy=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
@pytest.mark.parametrize("device", ["cuda", "mps"])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"rect",
|
"rect",
|
||||||
|
@ -3,9 +3,8 @@ import cv2
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from iopaint.helper import encode_pil_to_base64
|
|
||||||
from iopaint.schema import LDMSampler, HDStrategy, InpaintRequest, SDSampler
|
from iopaint.schema import LDMSampler, HDStrategy, InpaintRequest, SDSampler
|
||||||
from PIL import Image
|
import numpy as np
|
||||||
|
|
||||||
current_dir = Path(__file__).parent.absolute().resolve()
|
current_dir = Path(__file__).parent.absolute().resolve()
|
||||||
save_dir = current_dir / "result"
|
save_dir = current_dir / "result"
|
||||||
@ -32,6 +31,7 @@ def assert_equal(
|
|||||||
):
|
):
|
||||||
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
||||||
print(f"Input image shape: {img.shape}")
|
print(f"Input image shape: {img.shape}")
|
||||||
|
|
||||||
res = model(img, mask, config)
|
res = model(img, mask, config)
|
||||||
ok = cv2.imwrite(
|
ok = cv2.imwrite(
|
||||||
str(save_dir / gt_name),
|
str(save_dir / gt_name),
|
||||||
|
2
setup.py
2
setup.py
@ -28,7 +28,7 @@ def load_requirements():
|
|||||||
# https://setuptools.readthedocs.io/en/latest/setuptools.html#including-data-files
|
# https://setuptools.readthedocs.io/en/latest/setuptools.html#including-data-files
|
||||||
setuptools.setup(
|
setuptools.setup(
|
||||||
name="IOPaint",
|
name="IOPaint",
|
||||||
version="1.2.3",
|
version="1.3.0",
|
||||||
author="PanicByte",
|
author="PanicByte",
|
||||||
author_email="cwq1913@gmail.com",
|
author_email="cwq1913@gmail.com",
|
||||||
description="Image inpainting, outpainting tool powered by SOTA AI Model",
|
description="Image inpainting, outpainting tool powered by SOTA AI Model",
|
||||||
|
@ -62,7 +62,7 @@ const CV2Options = () => {
|
|||||||
/>
|
/>
|
||||||
<NumberInput
|
<NumberInput
|
||||||
id="cv2-radius"
|
id="cv2-radius"
|
||||||
className="w-[60px] rounded-full"
|
className="w-[50px] rounded-full"
|
||||||
numberValue={settings.cv2Radius}
|
numberValue={settings.cv2Radius}
|
||||||
allowFloat={false}
|
allowFloat={false}
|
||||||
onNumberValueChange={(val) => {
|
onNumberValueChange={(val) => {
|
||||||
|
@ -59,6 +59,10 @@ const DiffusionOptions = () => {
|
|||||||
updateExtenderDirection,
|
updateExtenderDirection,
|
||||||
adjustMask,
|
adjustMask,
|
||||||
clearMask,
|
clearMask,
|
||||||
|
updateEnablePowerPaintV2,
|
||||||
|
updateEnableBrushNet,
|
||||||
|
updateEnableControlnet,
|
||||||
|
updateLCMLora,
|
||||||
] = useStore((state) => [
|
] = useStore((state) => [
|
||||||
state.serverConfig.samplers,
|
state.serverConfig.samplers,
|
||||||
state.settings,
|
state.settings,
|
||||||
@ -71,6 +75,10 @@ const DiffusionOptions = () => {
|
|||||||
state.updateExtenderDirection,
|
state.updateExtenderDirection,
|
||||||
state.adjustMask,
|
state.adjustMask,
|
||||||
state.clearMask,
|
state.clearMask,
|
||||||
|
state.updateEnablePowerPaintV2,
|
||||||
|
state.updateEnableBrushNet,
|
||||||
|
state.updateEnableControlnet,
|
||||||
|
state.updateLCMLora,
|
||||||
])
|
])
|
||||||
const [exampleImage, isExampleImageLoaded] = useImage(paintByExampleFile)
|
const [exampleImage, isExampleImageLoaded] = useImage(paintByExampleFile)
|
||||||
const negativePromptRef = useRef(null)
|
const negativePromptRef = useRef(null)
|
||||||
@ -109,28 +117,109 @@ const DiffusionOptions = () => {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const renderBrushNetSetting = () => {
|
||||||
|
if (!settings.model.support_brushnet) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
let toolTip =
|
||||||
|
"BrushNet is a plug-and-play image inpainting model works on any SD1.5 base models."
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex flex-col gap-4">
|
||||||
|
<div className="flex flex-col gap-4">
|
||||||
|
<RowContainer>
|
||||||
|
<LabelTitle
|
||||||
|
text="BrushNet"
|
||||||
|
url="https://github.com/TencentARC/BrushNet"
|
||||||
|
toolTip={toolTip}
|
||||||
|
/>
|
||||||
|
<Switch
|
||||||
|
id="brushnet"
|
||||||
|
checked={settings.enableBrushNet}
|
||||||
|
onCheckedChange={(value) => {
|
||||||
|
updateEnableBrushNet(value)
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</RowContainer>
|
||||||
|
{/* <RowContainer>
|
||||||
|
<Slider
|
||||||
|
defaultValue={[100]}
|
||||||
|
className="w-[180px]"
|
||||||
|
min={1}
|
||||||
|
max={100}
|
||||||
|
step={1}
|
||||||
|
disabled={!settings.enableBrushNet || disable}
|
||||||
|
value={[Math.floor(settings.brushnetConditioningScale * 100)]}
|
||||||
|
onValueChange={(vals) =>
|
||||||
|
updateSettings({ brushnetConditioningScale: vals[0] / 100 })
|
||||||
|
}
|
||||||
|
/>
|
||||||
|
<NumberInput
|
||||||
|
id="brushnet-weight"
|
||||||
|
className="w-[50px] rounded-full"
|
||||||
|
numberValue={settings.brushnetConditioningScale}
|
||||||
|
allowFloat={false}
|
||||||
|
onNumberValueChange={(val) => {
|
||||||
|
updateSettings({ brushnetConditioningScale: val })
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</RowContainer> */}
|
||||||
|
|
||||||
|
<RowContainer>
|
||||||
|
<Select
|
||||||
|
defaultValue={settings.brushnetMethod}
|
||||||
|
value={settings.brushnetMethod}
|
||||||
|
onValueChange={(value) => {
|
||||||
|
updateSettings({ brushnetMethod: value })
|
||||||
|
}}
|
||||||
|
disabled={!settings.enableBrushNet}
|
||||||
|
>
|
||||||
|
<SelectTrigger>
|
||||||
|
<SelectValue placeholder="Select brushnet model" />
|
||||||
|
</SelectTrigger>
|
||||||
|
<SelectContent align="end">
|
||||||
|
<SelectGroup>
|
||||||
|
{Object.values(settings.model.brushnets).map((method) => (
|
||||||
|
<SelectItem key={method} value={method}>
|
||||||
|
{method.split("/")[1]}
|
||||||
|
</SelectItem>
|
||||||
|
))}
|
||||||
|
</SelectGroup>
|
||||||
|
</SelectContent>
|
||||||
|
</Select>
|
||||||
|
</RowContainer>
|
||||||
|
</div>
|
||||||
|
<Separator />
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
const renderConterNetSetting = () => {
|
const renderConterNetSetting = () => {
|
||||||
if (!settings.model.support_controlnet) {
|
if (!settings.model.support_controlnet) {
|
||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let toolTip =
|
||||||
|
"Using an additional conditioning image to control how an image is generated"
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col gap-4">
|
<div className="flex flex-col gap-4">
|
||||||
<div className="flex flex-col gap-4">
|
<div className="flex flex-col gap-4">
|
||||||
<div className="flex justify-between items-center pr-2">
|
<RowContainer>
|
||||||
<LabelTitle
|
<LabelTitle
|
||||||
text="ControlNet"
|
text="ControlNet"
|
||||||
toolTip="Using an additional conditioning image to control how an image is generated"
|
|
||||||
url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inpaint#controlnet"
|
url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inpaint#controlnet"
|
||||||
|
toolTip={toolTip}
|
||||||
/>
|
/>
|
||||||
<Switch
|
<Switch
|
||||||
id="controlnet"
|
id="controlnet"
|
||||||
checked={settings.enableControlnet}
|
checked={settings.enableControlnet}
|
||||||
onCheckedChange={(value) => {
|
onCheckedChange={(value) => {
|
||||||
updateSettings({ enableControlnet: value })
|
updateEnableControlnet(value)
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</div>
|
</RowContainer>
|
||||||
|
|
||||||
<div className="flex flex-col gap-1">
|
<div className="flex flex-col gap-1">
|
||||||
<RowContainer>
|
<RowContainer>
|
||||||
@ -148,7 +237,7 @@ const DiffusionOptions = () => {
|
|||||||
/>
|
/>
|
||||||
<NumberInput
|
<NumberInput
|
||||||
id="controlnet-weight"
|
id="controlnet-weight"
|
||||||
className="w-[60px] rounded-full"
|
className="w-[50px] rounded-full"
|
||||||
disabled={!settings.enableControlnet}
|
disabled={!settings.enableControlnet}
|
||||||
numberValue={settings.controlnetConditioningScale}
|
numberValue={settings.controlnetConditioningScale}
|
||||||
allowFloat={false}
|
allowFloat={false}
|
||||||
@ -159,7 +248,7 @@ const DiffusionOptions = () => {
|
|||||||
</RowContainer>
|
</RowContainer>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div className="pr-2">
|
<RowContainer>
|
||||||
<Select
|
<Select
|
||||||
defaultValue={settings.controlnetMethod}
|
defaultValue={settings.controlnetMethod}
|
||||||
value={settings.controlnetMethod}
|
value={settings.controlnetMethod}
|
||||||
@ -181,7 +270,7 @@ const DiffusionOptions = () => {
|
|||||||
</SelectGroup>
|
</SelectGroup>
|
||||||
</SelectContent>
|
</SelectContent>
|
||||||
</Select>
|
</Select>
|
||||||
</div>
|
</RowContainer>
|
||||||
</div>
|
</div>
|
||||||
<Separator />
|
<Separator />
|
||||||
</div>
|
</div>
|
||||||
@ -193,19 +282,22 @@ const DiffusionOptions = () => {
|
|||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let toolTip =
|
||||||
|
"Enable quality image generation in typically 2-8 steps. Suggest disabling guidance_scale by setting it to 0. You can also try values between 1.0 and 2.0. When LCM Lora is enabled, LCMSampler will be used automatically."
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<RowContainer>
|
<RowContainer>
|
||||||
<LabelTitle
|
<LabelTitle
|
||||||
text="LCM LoRA"
|
text="LCM LoRA"
|
||||||
url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inference_with_lcm_lora"
|
url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inference_with_lcm_lora"
|
||||||
toolTip="Enable quality image generation in typically 2-4 steps. Suggest disabling guidance_scale by setting it to 0. You can also try values between 1.0 and 2.0. When LCM Lora is enabled, LCMSampler will be used automatically."
|
toolTip={toolTip}
|
||||||
/>
|
/>
|
||||||
<Switch
|
<Switch
|
||||||
id="lcm-lora"
|
id="lcm-lora"
|
||||||
checked={settings.enableLCMLora}
|
checked={settings.enableLCMLora}
|
||||||
onCheckedChange={(value) => {
|
onCheckedChange={(value) => {
|
||||||
updateSettings({ enableLCMLora: value })
|
updateLCMLora(value)
|
||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</RowContainer>
|
</RowContainer>
|
||||||
@ -214,115 +306,6 @@ const DiffusionOptions = () => {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
const renderFreeu = () => {
|
|
||||||
if (!settings.model.support_freeu) {
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className="flex flex-col gap-4">
|
|
||||||
<div className="flex justify-between items-center pr-2">
|
|
||||||
<LabelTitle
|
|
||||||
text="FreeU"
|
|
||||||
toolTip="FreeU is a technique for improving image quality. Different models may require different FreeU-specific hyperparameters, which can be viewed in the more info section."
|
|
||||||
url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/freeu"
|
|
||||||
/>
|
|
||||||
<Switch
|
|
||||||
id="freeu"
|
|
||||||
checked={settings.enableFreeu}
|
|
||||||
onCheckedChange={(value) => {
|
|
||||||
updateSettings({ enableFreeu: value })
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<div className="flex flex-col gap-4">
|
|
||||||
<div className="flex justify-center gap-6">
|
|
||||||
<div className="flex gap-2 items-center justify-center">
|
|
||||||
<LabelTitle
|
|
||||||
htmlFor="freeu-s1"
|
|
||||||
text="s1"
|
|
||||||
disabled={!settings.enableFreeu}
|
|
||||||
/>
|
|
||||||
<NumberInput
|
|
||||||
id="freeu-s1"
|
|
||||||
className="w-14"
|
|
||||||
disabled={!settings.enableFreeu}
|
|
||||||
numberValue={settings.freeuConfig.s1}
|
|
||||||
allowFloat
|
|
||||||
onNumberValueChange={(value) => {
|
|
||||||
updateSettings({
|
|
||||||
freeuConfig: { ...settings.freeuConfig, s1: value },
|
|
||||||
})
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<div className="flex gap-2 items-center justify-center">
|
|
||||||
<LabelTitle
|
|
||||||
htmlFor="freeu-s2"
|
|
||||||
text="s2"
|
|
||||||
disabled={!settings.enableFreeu}
|
|
||||||
/>
|
|
||||||
<NumberInput
|
|
||||||
id="freeu-s2"
|
|
||||||
className="w-14"
|
|
||||||
disabled={!settings.enableFreeu}
|
|
||||||
numberValue={settings.freeuConfig.s2}
|
|
||||||
allowFloat
|
|
||||||
onNumberValueChange={(value) => {
|
|
||||||
updateSettings({
|
|
||||||
freeuConfig: { ...settings.freeuConfig, s2: value },
|
|
||||||
})
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div className="flex justify-center gap-6">
|
|
||||||
<div className="flex gap-2 items-center justify-center">
|
|
||||||
<LabelTitle
|
|
||||||
htmlFor="freeu-b1"
|
|
||||||
text="b1"
|
|
||||||
disabled={!settings.enableFreeu}
|
|
||||||
/>
|
|
||||||
<NumberInput
|
|
||||||
id="freeu-b1"
|
|
||||||
className="w-14"
|
|
||||||
disabled={!settings.enableFreeu}
|
|
||||||
numberValue={settings.freeuConfig.b1}
|
|
||||||
allowFloat
|
|
||||||
onNumberValueChange={(value) => {
|
|
||||||
updateSettings({
|
|
||||||
freeuConfig: { ...settings.freeuConfig, b1: value },
|
|
||||||
})
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<div className="flex gap-2 items-center justify-center">
|
|
||||||
<LabelTitle
|
|
||||||
htmlFor="freeu-b2"
|
|
||||||
text="b2"
|
|
||||||
disabled={!settings.enableFreeu}
|
|
||||||
/>
|
|
||||||
<NumberInput
|
|
||||||
id="freeu-b2"
|
|
||||||
className="w-14"
|
|
||||||
disabled={!settings.enableFreeu}
|
|
||||||
numberValue={settings.freeuConfig.b2}
|
|
||||||
allowFloat
|
|
||||||
onNumberValueChange={(value) => {
|
|
||||||
updateSettings({
|
|
||||||
freeuConfig: { ...settings.freeuConfig, b2: value },
|
|
||||||
})
|
|
||||||
}}
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<Separator />
|
|
||||||
</div>
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
const renderNegativePrompt = () => {
|
const renderNegativePrompt = () => {
|
||||||
if (!settings.model.need_prompt) {
|
if (!settings.model.need_prompt) {
|
||||||
return null
|
return null
|
||||||
@ -427,7 +410,7 @@ const DiffusionOptions = () => {
|
|||||||
/>
|
/>
|
||||||
<NumberInput
|
<NumberInput
|
||||||
id="image-guidance-scale"
|
id="image-guidance-scale"
|
||||||
className="w-[60px] rounded-full"
|
className="w-[50px] rounded-full"
|
||||||
numberValue={settings.p2pImageGuidanceScale}
|
numberValue={settings.p2pImageGuidanceScale}
|
||||||
allowFloat
|
allowFloat
|
||||||
onNumberValueChange={(val) => {
|
onNumberValueChange={(val) => {
|
||||||
@ -444,16 +427,22 @@ const DiffusionOptions = () => {
|
|||||||
return null
|
return null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let toolTip =
|
||||||
|
"Strength is a measure of how much noise is added to the base image, which influences how similar the output is to the base image. Higher value means more noise and more different from the base image"
|
||||||
|
// if (disable) {
|
||||||
|
// toolTip = "BrushNet is enabled, Strength is disabled."
|
||||||
|
// }
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col gap-1">
|
<RowContainer>
|
||||||
<LabelTitle
|
<LabelTitle
|
||||||
text="Strength"
|
text="Strength"
|
||||||
url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inpaint#strength"
|
url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inpaint#strength"
|
||||||
toolTip="Strength is a measure of how much noise is added to the base image, which influences how similar the output is to the base image. Higher value means more noise and more different from the base image"
|
toolTip={toolTip}
|
||||||
|
// disabled={disable}
|
||||||
/>
|
/>
|
||||||
<RowContainer>
|
|
||||||
<Slider
|
<Slider
|
||||||
className="w-[180px]"
|
className="w-[110px]"
|
||||||
defaultValue={[100]}
|
defaultValue={[100]}
|
||||||
min={10}
|
min={10}
|
||||||
max={100}
|
max={100}
|
||||||
@ -462,18 +451,19 @@ const DiffusionOptions = () => {
|
|||||||
onValueChange={(vals) =>
|
onValueChange={(vals) =>
|
||||||
updateSettings({ sdStrength: vals[0] / 100 })
|
updateSettings({ sdStrength: vals[0] / 100 })
|
||||||
}
|
}
|
||||||
|
// disabled={disable}
|
||||||
/>
|
/>
|
||||||
<NumberInput
|
<NumberInput
|
||||||
id="strength"
|
id="strength"
|
||||||
className="w-[60px] rounded-full"
|
className="w-[50px] rounded-full"
|
||||||
numberValue={settings.sdStrength}
|
numberValue={settings.sdStrength}
|
||||||
allowFloat
|
allowFloat
|
||||||
onNumberValueChange={(val) => {
|
onNumberValueChange={(val) => {
|
||||||
updateSettings({ sdStrength: val })
|
updateSettings({ sdStrength: val })
|
||||||
}}
|
}}
|
||||||
|
// disabled={disable}
|
||||||
/>
|
/>
|
||||||
</RowContainer>
|
</RowContainer>
|
||||||
</div>
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -483,7 +473,7 @@ const DiffusionOptions = () => {
|
|||||||
}
|
}
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<div className="flex flex-col gap-4">
|
<div className="flex flex-col gap-2">
|
||||||
<RowContainer>
|
<RowContainer>
|
||||||
<LabelTitle
|
<LabelTitle
|
||||||
text="Extender"
|
text="Extender"
|
||||||
@ -560,10 +550,6 @@ const DiffusionOptions = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const renderPowerPaintTaskType = () => {
|
const renderPowerPaintTaskType = () => {
|
||||||
if (settings.model.name !== POWERPAINT) {
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<RowContainer>
|
<RowContainer>
|
||||||
<LabelTitle
|
<LabelTitle
|
||||||
@ -578,7 +564,7 @@ const DiffusionOptions = () => {
|
|||||||
}}
|
}}
|
||||||
disabled={settings.showExtender}
|
disabled={settings.showExtender}
|
||||||
>
|
>
|
||||||
<SelectTrigger className="w-[140px]">
|
<SelectTrigger className="w-[130px]">
|
||||||
<SelectValue placeholder="Select task" />
|
<SelectValue placeholder="Select task" />
|
||||||
</SelectTrigger>
|
</SelectTrigger>
|
||||||
<SelectContent align="end">
|
<SelectContent align="end">
|
||||||
@ -586,6 +572,7 @@ const DiffusionOptions = () => {
|
|||||||
{[
|
{[
|
||||||
PowerPaintTask.text_guided,
|
PowerPaintTask.text_guided,
|
||||||
PowerPaintTask.object_remove,
|
PowerPaintTask.object_remove,
|
||||||
|
PowerPaintTask.context_aware,
|
||||||
PowerPaintTask.shape_guided,
|
PowerPaintTask.shape_guided,
|
||||||
].map((task) => (
|
].map((task) => (
|
||||||
<SelectItem key={task} value={task}>
|
<SelectItem key={task} value={task}>
|
||||||
@ -599,17 +586,54 @@ const DiffusionOptions = () => {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const renderPowerPaintV1 = () => {
|
||||||
|
if (settings.model.name !== POWERPAINT) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
{renderPowerPaintTaskType()}
|
||||||
|
<Separator />
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const renderPowerPaintV2 = () => {
|
||||||
|
if (settings.model.support_powerpaint_v2 === false) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<RowContainer>
|
||||||
|
<LabelTitle
|
||||||
|
text="PowerPaint V2"
|
||||||
|
toolTip="PowerPaint is a plug-and-play image inpainting model works on any SD1.5 base models."
|
||||||
|
/>
|
||||||
|
<Switch
|
||||||
|
id="powerpaint-v2"
|
||||||
|
checked={settings.enablePowerPaintV2}
|
||||||
|
onCheckedChange={(value) => {
|
||||||
|
updateEnablePowerPaintV2(value)
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</RowContainer>
|
||||||
|
{renderPowerPaintTaskType()}
|
||||||
|
<Separator />
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
const renderSteps = () => {
|
const renderSteps = () => {
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col gap-1">
|
<RowContainer>
|
||||||
<LabelTitle
|
<LabelTitle
|
||||||
htmlFor="steps"
|
htmlFor="steps"
|
||||||
text="Steps"
|
text="Steps"
|
||||||
toolTip="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference."
|
toolTip="The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference."
|
||||||
/>
|
/>
|
||||||
<RowContainer>
|
|
||||||
<Slider
|
<Slider
|
||||||
className="w-[180px]"
|
className="w-[110px]"
|
||||||
defaultValue={[30]}
|
defaultValue={[30]}
|
||||||
min={1}
|
min={1}
|
||||||
max={100}
|
max={100}
|
||||||
@ -619,7 +643,7 @@ const DiffusionOptions = () => {
|
|||||||
/>
|
/>
|
||||||
<NumberInput
|
<NumberInput
|
||||||
id="steps"
|
id="steps"
|
||||||
className="w-[60px] rounded-full"
|
className="w-[50px] rounded-full"
|
||||||
numberValue={settings.sdSteps}
|
numberValue={settings.sdSteps}
|
||||||
allowFloat={false}
|
allowFloat={false}
|
||||||
onNumberValueChange={(val) => {
|
onNumberValueChange={(val) => {
|
||||||
@ -627,21 +651,19 @@ const DiffusionOptions = () => {
|
|||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</RowContainer>
|
</RowContainer>
|
||||||
</div>
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
const renderGuidanceScale = () => {
|
const renderGuidanceScale = () => {
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col gap-1">
|
<RowContainer>
|
||||||
<LabelTitle
|
<LabelTitle
|
||||||
text="Guidance scale"
|
text="Guidance"
|
||||||
url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inpaint#guidance-scale"
|
url="https://huggingface.co/docs/diffusers/main/en/using-diffusers/inpaint#guidance-scale"
|
||||||
toolTip="Guidance scale affects how aligned the text prompt and generated image are. Higher value means the prompt and generated image are closely aligned, so the output is a stricter interpretation of the prompt"
|
toolTip="Guidance scale affects how aligned the text prompt and generated image are. Higher value means the prompt and generated image are closely aligned, so the output is a stricter interpretation of the prompt"
|
||||||
/>
|
/>
|
||||||
<RowContainer>
|
|
||||||
<Slider
|
<Slider
|
||||||
className="w-[180px]"
|
className="w-[110px]"
|
||||||
defaultValue={[750]}
|
defaultValue={[750]}
|
||||||
min={0}
|
min={0}
|
||||||
max={1500}
|
max={1500}
|
||||||
@ -652,8 +674,8 @@ const DiffusionOptions = () => {
|
|||||||
}
|
}
|
||||||
/>
|
/>
|
||||||
<NumberInput
|
<NumberInput
|
||||||
id="guidance-scale"
|
id="guid"
|
||||||
className="w-[60px] rounded-full"
|
className="w-[50px] rounded-full"
|
||||||
numberValue={settings.sdGuidanceScale}
|
numberValue={settings.sdGuidanceScale}
|
||||||
allowFloat
|
allowFloat
|
||||||
onNumberValueChange={(val) => {
|
onNumberValueChange={(val) => {
|
||||||
@ -661,7 +683,6 @@ const DiffusionOptions = () => {
|
|||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</RowContainer>
|
</RowContainer>
|
||||||
</div>
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -716,7 +737,7 @@ const DiffusionOptions = () => {
|
|||||||
/>
|
/>
|
||||||
<NumberInput
|
<NumberInput
|
||||||
id="seed"
|
id="seed"
|
||||||
className="w-[100px]"
|
className="w-[110px]"
|
||||||
disabled={!settings.seedFixed}
|
disabled={!settings.seedFixed}
|
||||||
numberValue={settings.seed}
|
numberValue={settings.seed}
|
||||||
allowFloat={false}
|
allowFloat={false}
|
||||||
@ -731,14 +752,14 @@ const DiffusionOptions = () => {
|
|||||||
|
|
||||||
const renderMaskBlur = () => {
|
const renderMaskBlur = () => {
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col gap-1">
|
<>
|
||||||
|
<RowContainer>
|
||||||
<LabelTitle
|
<LabelTitle
|
||||||
text="Mask blur"
|
text="Mask blur"
|
||||||
toolTip="How much to blur the mask before processing, in pixels. Make the generated inpainting boundaries appear more natural."
|
toolTip="How much to blur the mask before processing, in pixels. Make the generated inpainting boundaries appear more natural."
|
||||||
/>
|
/>
|
||||||
<RowContainer>
|
|
||||||
<Slider
|
<Slider
|
||||||
className="w-[180px]"
|
className="w-[110px]"
|
||||||
defaultValue={[settings.sdMaskBlur]}
|
defaultValue={[settings.sdMaskBlur]}
|
||||||
min={0}
|
min={0}
|
||||||
max={96}
|
max={96}
|
||||||
@ -748,7 +769,7 @@ const DiffusionOptions = () => {
|
|||||||
/>
|
/>
|
||||||
<NumberInput
|
<NumberInput
|
||||||
id="mask-blur"
|
id="mask-blur"
|
||||||
className="w-[60px] rounded-full"
|
className="w-[50px] rounded-full"
|
||||||
numberValue={settings.sdMaskBlur}
|
numberValue={settings.sdMaskBlur}
|
||||||
allowFloat={false}
|
allowFloat={false}
|
||||||
onNumberValueChange={(value) => {
|
onNumberValueChange={(value) => {
|
||||||
@ -756,7 +777,8 @@ const DiffusionOptions = () => {
|
|||||||
}}
|
}}
|
||||||
/>
|
/>
|
||||||
</RowContainer>
|
</RowContainer>
|
||||||
</div>
|
<Separator />
|
||||||
|
</>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -785,15 +807,15 @@ const DiffusionOptions = () => {
|
|||||||
const renderMaskAdjuster = () => {
|
const renderMaskAdjuster = () => {
|
||||||
return (
|
return (
|
||||||
<>
|
<>
|
||||||
<div className="flex flex-col gap-1">
|
<div className="flex flex-col gap-2">
|
||||||
|
<RowContainer>
|
||||||
<LabelTitle
|
<LabelTitle
|
||||||
htmlFor="adjustMaskKernelSize"
|
htmlFor="adjustMaskKernelSize"
|
||||||
text="Adjust Mask"
|
text="Mask OP"
|
||||||
toolTip="Expand or shrink mask. Using the slider to adjust the kernel size for dilation or erosion."
|
toolTip="Expand or shrink mask. Using the slider to adjust the kernel size for dilation or erosion."
|
||||||
/>
|
/>
|
||||||
<RowContainer>
|
|
||||||
<Slider
|
<Slider
|
||||||
className="w-[180px]"
|
className="w-[110px]"
|
||||||
defaultValue={[12]}
|
defaultValue={[12]}
|
||||||
min={1}
|
min={1}
|
||||||
max={100}
|
max={100}
|
||||||
@ -805,7 +827,7 @@ const DiffusionOptions = () => {
|
|||||||
/>
|
/>
|
||||||
<NumberInput
|
<NumberInput
|
||||||
id="adjustMaskKernelSize"
|
id="adjustMaskKernelSize"
|
||||||
className="w-[60px] rounded-full"
|
className="w-[50px] rounded-full"
|
||||||
numberValue={settings.adjustMaskKernelSize}
|
numberValue={settings.adjustMaskKernelSize}
|
||||||
allowFloat={false}
|
allowFloat={false}
|
||||||
onNumberValueChange={(val) => {
|
onNumberValueChange={(val) => {
|
||||||
@ -815,7 +837,6 @@ const DiffusionOptions = () => {
|
|||||||
</RowContainer>
|
</RowContainer>
|
||||||
|
|
||||||
<RowContainer>
|
<RowContainer>
|
||||||
<div className="flex gap-1 justify-start">
|
|
||||||
<Button
|
<Button
|
||||||
variant="outline"
|
variant="outline"
|
||||||
className="p-1 h-8"
|
className="p-1 h-8"
|
||||||
@ -846,11 +867,8 @@ const DiffusionOptions = () => {
|
|||||||
onClick={() => adjustMask("reverse")}
|
onClick={() => adjustMask("reverse")}
|
||||||
disabled={isProcessing}
|
disabled={isProcessing}
|
||||||
>
|
>
|
||||||
<div className="flex items-center gap-1 select-none">
|
<div className="flex items-center gap-1 select-none">Reverse</div>
|
||||||
Reverse
|
|
||||||
</div>
|
|
||||||
</Button>
|
</Button>
|
||||||
</div>
|
|
||||||
|
|
||||||
<Button
|
<Button
|
||||||
variant="outline"
|
variant="outline"
|
||||||
@ -868,11 +886,13 @@ const DiffusionOptions = () => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div className="flex flex-col gap-4 mt-4">
|
<div className="flex flex-col gap-[14px] mt-4">
|
||||||
{renderCropper()}
|
{renderCropper()}
|
||||||
{renderExtender()}
|
{renderExtender()}
|
||||||
|
{renderMaskBlur()}
|
||||||
{renderMaskAdjuster()}
|
{renderMaskAdjuster()}
|
||||||
{renderPowerPaintTaskType()}
|
{renderMatchHistograms()}
|
||||||
|
{renderPowerPaintV1()}
|
||||||
{renderSteps()}
|
{renderSteps()}
|
||||||
{renderGuidanceScale()}
|
{renderGuidanceScale()}
|
||||||
{renderP2PImageGuidanceScale()}
|
{renderP2PImageGuidanceScale()}
|
||||||
@ -881,11 +901,10 @@ const DiffusionOptions = () => {
|
|||||||
{renderSeed()}
|
{renderSeed()}
|
||||||
{renderNegativePrompt()}
|
{renderNegativePrompt()}
|
||||||
<Separator />
|
<Separator />
|
||||||
|
{renderBrushNetSetting()}
|
||||||
|
{renderPowerPaintV2()}
|
||||||
{renderConterNetSetting()}
|
{renderConterNetSetting()}
|
||||||
{renderLCMLora()}
|
{renderLCMLora()}
|
||||||
{renderMaskBlur()}
|
|
||||||
{renderMatchHistograms()}
|
|
||||||
{renderFreeu()}
|
|
||||||
{renderPaintByExample()}
|
{renderPaintByExample()}
|
||||||
</div>
|
</div>
|
||||||
)
|
)
|
||||||
|
@ -38,7 +38,7 @@ const LDMOptions = () => {
|
|||||||
/>
|
/>
|
||||||
<NumberInput
|
<NumberInput
|
||||||
id="steps"
|
id="steps"
|
||||||
className="w-[60px] rounded-full"
|
className="w-[50px] rounded-full"
|
||||||
numberValue={settings.ldmSteps}
|
numberValue={settings.ldmSteps}
|
||||||
allowFloat={false}
|
allowFloat={false}
|
||||||
onNumberValueChange={(val) => {
|
onNumberValueChange={(val) => {
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
|
import { cn } from "@/lib/utils"
|
||||||
import { Button } from "../ui/button"
|
import { Button } from "../ui/button"
|
||||||
import { Label } from "../ui/label"
|
import { Label } from "../ui/label"
|
||||||
import { Tooltip, TooltipContent, TooltipTrigger } from "../ui/tooltip"
|
import { Tooltip, TooltipContent, TooltipTrigger } from "../ui/tooltip"
|
||||||
|
|
||||||
const RowContainer = ({ children }: { children: React.ReactNode }) => (
|
const RowContainer = ({ children }: { children: React.ReactNode }) => (
|
||||||
<div className="flex justify-between items-center pr-2">{children}</div>
|
<div className="flex justify-between items-center pr-4">{children}</div>
|
||||||
)
|
)
|
||||||
|
|
||||||
const LabelTitle = ({
|
const LabelTitle = ({
|
||||||
@ -12,19 +13,21 @@ const LabelTitle = ({
|
|||||||
url,
|
url,
|
||||||
htmlFor,
|
htmlFor,
|
||||||
disabled = false,
|
disabled = false,
|
||||||
|
className = "",
|
||||||
}: {
|
}: {
|
||||||
text: string
|
text: string
|
||||||
toolTip?: string
|
toolTip?: string
|
||||||
url?: string
|
url?: string
|
||||||
htmlFor?: string
|
htmlFor?: string
|
||||||
disabled?: boolean
|
disabled?: boolean
|
||||||
|
className?: string
|
||||||
}) => {
|
}) => {
|
||||||
return (
|
return (
|
||||||
<Tooltip>
|
<Tooltip>
|
||||||
<TooltipTrigger asChild>
|
<TooltipTrigger asChild>
|
||||||
<Label
|
<Label
|
||||||
htmlFor={htmlFor ? htmlFor : text.toLowerCase().replace(" ", "-")}
|
htmlFor={htmlFor ? htmlFor : text.toLowerCase().replace(" ", "-")}
|
||||||
className="font-medium"
|
className={cn("font-medium min-w-[65px]", className)}
|
||||||
disabled={disabled}
|
disabled={disabled}
|
||||||
>
|
>
|
||||||
{text}
|
{text}
|
||||||
|
@ -61,7 +61,7 @@ const SidePanel = () => {
|
|||||||
</SheetTrigger>
|
</SheetTrigger>
|
||||||
<SheetContent
|
<SheetContent
|
||||||
side="right"
|
side="right"
|
||||||
className="w-[300px] mt-[60px] outline-none pl-4 pr-1"
|
className="w-[286px] mt-[60px] outline-none pl-3 pr-1"
|
||||||
onOpenAutoFocus={(event) => event.preventDefault()}
|
onOpenAutoFocus={(event) => event.preventDefault()}
|
||||||
onPointerDownOutside={(event) => event.preventDefault()}
|
onPointerDownOutside={(event) => event.preventDefault()}
|
||||||
>
|
>
|
||||||
@ -85,10 +85,7 @@ const SidePanel = () => {
|
|||||||
</RowContainer>
|
</RowContainer>
|
||||||
<Separator />
|
<Separator />
|
||||||
</SheetHeader>
|
</SheetHeader>
|
||||||
<ScrollArea
|
<ScrollArea style={{ height: windowSize.height - 160 }}>
|
||||||
style={{ height: windowSize.height - 160 }}
|
|
||||||
className="pr-3"
|
|
||||||
>
|
|
||||||
{renderSidePanelOptions()}
|
{renderSidePanelOptions()}
|
||||||
</ScrollArea>
|
</ScrollArea>
|
||||||
</SheetContent>
|
</SheetContent>
|
||||||
|
@ -44,7 +44,10 @@ export interface NumberInputProps extends InputProps {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const NumberInput = React.forwardRef<HTMLInputElement, NumberInputProps>(
|
const NumberInput = React.forwardRef<HTMLInputElement, NumberInputProps>(
|
||||||
({ numberValue, allowFloat, onNumberValueChange, ...rest }, ref) => {
|
(
|
||||||
|
{ numberValue, allowFloat, onNumberValueChange, className, ...rest },
|
||||||
|
ref
|
||||||
|
) => {
|
||||||
const [value, setValue] = React.useState<string>(numberValue.toString())
|
const [value, setValue] = React.useState<string>(numberValue.toString())
|
||||||
|
|
||||||
React.useEffect(() => {
|
React.useEffect(() => {
|
||||||
@ -75,7 +78,15 @@ const NumberInput = React.forwardRef<HTMLInputElement, NumberInputProps>(
|
|||||||
setValue(val)
|
setValue(val)
|
||||||
}
|
}
|
||||||
|
|
||||||
return <Input ref={ref} value={value} onInput={onInput} {...rest} />
|
return (
|
||||||
|
<Input
|
||||||
|
ref={ref}
|
||||||
|
value={value}
|
||||||
|
onInput={onInput}
|
||||||
|
className={cn("text-center h-7 px-1", className)}
|
||||||
|
{...rest}
|
||||||
|
/>
|
||||||
|
)
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ const SelectTrigger = React.forwardRef<
|
|||||||
<SelectPrimitive.Trigger
|
<SelectPrimitive.Trigger
|
||||||
ref={ref}
|
ref={ref}
|
||||||
className={cn(
|
className={cn(
|
||||||
"flex h-9 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent px-3 py-2 text-sm shadow-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-1 focus:ring-ring disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1",
|
"flex h-9 w-full items-center justify-between whitespace-nowrap rounded-md border border-input bg-transparent pl-2 pr-1 py-2 text-sm shadow-sm ring-offset-background placeholder:text-muted-foreground focus:outline-none focus:ring-1 focus:ring-ring disabled:cursor-not-allowed disabled:opacity-50 [&>span]:line-clamp-1",
|
||||||
className
|
className
|
||||||
)}
|
)}
|
||||||
tabIndex={-1}
|
tabIndex={-1}
|
||||||
|
@ -68,8 +68,6 @@ export default async function inpaint(
|
|||||||
sd_sampler: settings.sdSampler,
|
sd_sampler: settings.sdSampler,
|
||||||
sd_seed: settings.seedFixed ? settings.seed : -1,
|
sd_seed: settings.seedFixed ? settings.seed : -1,
|
||||||
sd_match_histograms: settings.sdMatchHistograms,
|
sd_match_histograms: settings.sdMatchHistograms,
|
||||||
sd_freeu: settings.enableFreeu,
|
|
||||||
sd_freeu_config: settings.freeuConfig,
|
|
||||||
sd_lcm_lora: settings.enableLCMLora,
|
sd_lcm_lora: settings.enableLCMLora,
|
||||||
paint_by_example_example_image: exampleImageBase64,
|
paint_by_example_example_image: exampleImageBase64,
|
||||||
p2p_image_guidance_scale: settings.p2pImageGuidanceScale,
|
p2p_image_guidance_scale: settings.p2pImageGuidanceScale,
|
||||||
@ -78,6 +76,10 @@ export default async function inpaint(
|
|||||||
controlnet_method: settings.controlnetMethod
|
controlnet_method: settings.controlnetMethod
|
||||||
? settings.controlnetMethod
|
? settings.controlnetMethod
|
||||||
: "",
|
: "",
|
||||||
|
enable_brushnet: settings.enableBrushNet,
|
||||||
|
brushnet_method: settings.brushnetMethod ? settings.brushnetMethod : "",
|
||||||
|
brushnet_conditioning_scale: settings.brushnetConditioningScale,
|
||||||
|
enable_powerpaint_v2: settings.enablePowerPaintV2,
|
||||||
powerpaint_task: settings.showExtender
|
powerpaint_task: settings.showExtender
|
||||||
? PowerPaintTask.outpainting
|
? PowerPaintTask.outpainting
|
||||||
: settings.powerpaintTask,
|
: settings.powerpaintTask,
|
||||||
|
@ -7,7 +7,6 @@ import {
|
|||||||
AdjustMaskOperate,
|
AdjustMaskOperate,
|
||||||
CV2Flag,
|
CV2Flag,
|
||||||
ExtenderDirection,
|
ExtenderDirection,
|
||||||
FreeuConfig,
|
|
||||||
LDMSampler,
|
LDMSampler,
|
||||||
Line,
|
Line,
|
||||||
LineGroup,
|
LineGroup,
|
||||||
@ -99,11 +98,15 @@ export type Settings = {
|
|||||||
controlnetConditioningScale: number
|
controlnetConditioningScale: number
|
||||||
controlnetMethod: string
|
controlnetMethod: string
|
||||||
|
|
||||||
|
// BrushNet
|
||||||
|
enableBrushNet: boolean
|
||||||
|
brushnetMethod: string
|
||||||
|
brushnetConditioningScale: number
|
||||||
|
|
||||||
enableLCMLora: boolean
|
enableLCMLora: boolean
|
||||||
enableFreeu: boolean
|
|
||||||
freeuConfig: FreeuConfig
|
|
||||||
|
|
||||||
// PowerPaint
|
// PowerPaint
|
||||||
|
enablePowerPaintV2: boolean
|
||||||
powerpaintTask: PowerPaintTask
|
powerpaintTask: PowerPaintTask
|
||||||
|
|
||||||
// AdjustMask
|
// AdjustMask
|
||||||
@ -192,6 +195,13 @@ type AppAction = {
|
|||||||
setServerConfig: (newValue: ServerConfig) => void
|
setServerConfig: (newValue: ServerConfig) => void
|
||||||
setSeed: (newValue: number) => void
|
setSeed: (newValue: number) => void
|
||||||
updateSettings: (newSettings: Partial<Settings>) => void
|
updateSettings: (newSettings: Partial<Settings>) => void
|
||||||
|
|
||||||
|
// 互斥
|
||||||
|
updateEnablePowerPaintV2: (newValue: boolean) => void
|
||||||
|
updateEnableBrushNet: (newValue: boolean) => void
|
||||||
|
updateEnableControlnet: (newValue: boolean) => void
|
||||||
|
updateLCMLora: (newValue: boolean) => void
|
||||||
|
|
||||||
setModel: (newModel: ModelInfo) => void
|
setModel: (newModel: ModelInfo) => void
|
||||||
updateFileManagerState: (newState: Partial<FileManagerState>) => void
|
updateFileManagerState: (newState: Partial<FileManagerState>) => void
|
||||||
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
|
updateInteractiveSegState: (newState: Partial<InteractiveSegState>) => void
|
||||||
@ -306,15 +316,16 @@ const defaultValues: AppState = {
|
|||||||
path: "lama",
|
path: "lama",
|
||||||
model_type: "inpaint",
|
model_type: "inpaint",
|
||||||
support_controlnet: false,
|
support_controlnet: false,
|
||||||
|
support_brushnet: false,
|
||||||
support_strength: false,
|
support_strength: false,
|
||||||
support_outpainting: false,
|
support_outpainting: false,
|
||||||
|
support_powerpaint_v2: false,
|
||||||
controlnets: [],
|
controlnets: [],
|
||||||
support_freeu: false,
|
brushnets: [],
|
||||||
support_lcm_lora: false,
|
support_lcm_lora: false,
|
||||||
is_single_file_diffusers: false,
|
is_single_file_diffusers: false,
|
||||||
need_prompt: false,
|
need_prompt: false,
|
||||||
},
|
},
|
||||||
enableControlnet: false,
|
|
||||||
showCropper: false,
|
showCropper: false,
|
||||||
showExtender: false,
|
showExtender: false,
|
||||||
extenderDirection: ExtenderDirection.xy,
|
extenderDirection: ExtenderDirection.xy,
|
||||||
@ -339,11 +350,14 @@ const defaultValues: AppState = {
|
|||||||
sdMatchHistograms: false,
|
sdMatchHistograms: false,
|
||||||
sdScale: 1.0,
|
sdScale: 1.0,
|
||||||
p2pImageGuidanceScale: 1.5,
|
p2pImageGuidanceScale: 1.5,
|
||||||
controlnetConditioningScale: 0.4,
|
enableControlnet: false,
|
||||||
controlnetMethod: "lllyasviel/control_v11p_sd15_canny",
|
controlnetMethod: "lllyasviel/control_v11p_sd15_canny",
|
||||||
|
controlnetConditioningScale: 0.4,
|
||||||
|
enableBrushNet: false,
|
||||||
|
brushnetMethod: "random_mask",
|
||||||
|
brushnetConditioningScale: 1.0,
|
||||||
enableLCMLora: false,
|
enableLCMLora: false,
|
||||||
enableFreeu: false,
|
enablePowerPaintV2: false,
|
||||||
freeuConfig: { s1: 0.9, s2: 0.2, b1: 1.2, b2: 1.4 },
|
|
||||||
powerpaintTask: PowerPaintTask.text_guided,
|
powerpaintTask: PowerPaintTask.text_guided,
|
||||||
adjustMaskKernelSize: 12,
|
adjustMaskKernelSize: 12,
|
||||||
},
|
},
|
||||||
@ -421,6 +435,8 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
if (
|
if (
|
||||||
get().settings.model.support_outpainting &&
|
get().settings.model.support_outpainting &&
|
||||||
settings.showExtender &&
|
settings.showExtender &&
|
||||||
|
extenderState.x === 0 &&
|
||||||
|
extenderState.y === 0 &&
|
||||||
extenderState.height === imageHeight &&
|
extenderState.height === imageHeight &&
|
||||||
extenderState.width === imageWidth
|
extenderState.width === imageWidth
|
||||||
) {
|
) {
|
||||||
@ -794,6 +810,48 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
})
|
})
|
||||||
},
|
},
|
||||||
|
|
||||||
|
updateEnablePowerPaintV2: (newValue: boolean) => {
|
||||||
|
get().updateSettings({ enablePowerPaintV2: newValue })
|
||||||
|
if (newValue) {
|
||||||
|
get().updateSettings({
|
||||||
|
enableBrushNet: false,
|
||||||
|
enableControlnet: false,
|
||||||
|
enableLCMLora: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
updateEnableBrushNet: (newValue: boolean) => {
|
||||||
|
get().updateSettings({ enableBrushNet: newValue })
|
||||||
|
if (newValue) {
|
||||||
|
get().updateSettings({
|
||||||
|
enablePowerPaintV2: false,
|
||||||
|
enableControlnet: false,
|
||||||
|
enableLCMLora: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
updateEnableControlnet(newValue) {
|
||||||
|
get().updateSettings({ enableControlnet: newValue })
|
||||||
|
if (newValue) {
|
||||||
|
get().updateSettings({
|
||||||
|
enablePowerPaintV2: false,
|
||||||
|
enableBrushNet: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
updateLCMLora(newValue) {
|
||||||
|
get().updateSettings({ enableLCMLora: newValue })
|
||||||
|
if (newValue) {
|
||||||
|
get().updateSettings({
|
||||||
|
enablePowerPaintV2: false,
|
||||||
|
enableBrushNet: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
setModel: (newModel: ModelInfo) => {
|
setModel: (newModel: ModelInfo) => {
|
||||||
set((state) => {
|
set((state) => {
|
||||||
state.settings.model = newModel
|
state.settings.model = newModel
|
||||||
@ -1076,7 +1134,7 @@ export const useStore = createWithEqualityFn<AppState & AppAction>()(
|
|||||||
})),
|
})),
|
||||||
{
|
{
|
||||||
name: "ZUSTAND_STATE", // name of the item in the storage (must be unique)
|
name: "ZUSTAND_STATE", // name of the item in the storage (must be unique)
|
||||||
version: 1,
|
version: 2,
|
||||||
partialize: (state) =>
|
partialize: (state) =>
|
||||||
Object.fromEntries(
|
Object.fromEntries(
|
||||||
Object.entries(state).filter(([key]) =>
|
Object.entries(state).filter(([key]) =>
|
||||||
|
@ -48,8 +48,10 @@ export interface ModelInfo {
|
|||||||
support_strength: boolean
|
support_strength: boolean
|
||||||
support_outpainting: boolean
|
support_outpainting: boolean
|
||||||
support_controlnet: boolean
|
support_controlnet: boolean
|
||||||
|
support_brushnet: boolean
|
||||||
|
support_powerpaint_v2: boolean
|
||||||
controlnets: string[]
|
controlnets: string[]
|
||||||
support_freeu: boolean
|
brushnets: string[]
|
||||||
support_lcm_lora: boolean
|
support_lcm_lora: boolean
|
||||||
need_prompt: boolean
|
need_prompt: boolean
|
||||||
is_single_file_diffusers: boolean
|
is_single_file_diffusers: boolean
|
||||||
@ -96,13 +98,6 @@ export interface Rect {
|
|||||||
height: number
|
height: number
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface FreeuConfig {
|
|
||||||
s1: number
|
|
||||||
s2: number
|
|
||||||
b1: number
|
|
||||||
b2: number
|
|
||||||
}
|
|
||||||
|
|
||||||
export interface Point {
|
export interface Point {
|
||||||
x: number
|
x: number
|
||||||
y: number
|
y: number
|
||||||
@ -129,6 +124,7 @@ export enum ExtenderDirection {
|
|||||||
export enum PowerPaintTask {
|
export enum PowerPaintTask {
|
||||||
text_guided = "text-guided",
|
text_guided = "text-guided",
|
||||||
shape_guided = "shape-guided",
|
shape_guided = "shape-guided",
|
||||||
|
context_aware = "context-aware",
|
||||||
object_remove = "object-remove",
|
object_remove = "object-remove",
|
||||||
outpainting = "outpainting",
|
outpainting = "outpainting",
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user