add powerpaint v2

This commit is contained in:
Qing 2024-04-24 20:22:29 +08:00
parent ccea072dc5
commit 911f7224b6
14 changed files with 8082 additions and 2318 deletions

View File

@ -13,7 +13,7 @@ from iopaint.helper import (
switch_mps_device, switch_mps_device,
) )
from iopaint.schema import InpaintRequest, HDStrategy, SDSampler from iopaint.schema import InpaintRequest, HDStrategy, SDSampler
from .helper.g_diffuser_bot import expand_image, expand_image2 from .helper.g_diffuser_bot import expand_image
from .utils import get_scheduler from .utils import get_scheduler
@ -35,8 +35,7 @@ class InpaintModel:
self.init_model(device, **kwargs) self.init_model(device, **kwargs)
@abc.abstractmethod @abc.abstractmethod
def init_model(self, device, **kwargs): def init_model(self, device, **kwargs): ...
...
@staticmethod @staticmethod
@abc.abstractmethod @abc.abstractmethod
@ -53,8 +52,7 @@ class InpaintModel:
... ...
@staticmethod @staticmethod
def download(): def download(): ...
...
def _pad_forward(self, image, mask, config: InpaintRequest): def _pad_forward(self, image, mask, config: InpaintRequest):
origin_height, origin_width = image.shape[:2] origin_height, origin_width = image.shape[:2]
@ -96,7 +94,7 @@ class InpaintModel:
# logger.info(f"hd_strategy: {config.hd_strategy}") # logger.info(f"hd_strategy: {config.hd_strategy}")
if config.hd_strategy == HDStrategy.CROP: if config.hd_strategy == HDStrategy.CROP:
if max(image.shape) > config.hd_strategy_crop_trigger_size: if max(image.shape) > config.hd_strategy_crop_trigger_size:
logger.info(f"Run crop strategy") logger.info("Run crop strategy")
boxes = boxes_from_mask(mask) boxes = boxes_from_mask(mask)
crop_result = [] crop_result = []
for box in boxes: for box in boxes:
@ -327,14 +325,12 @@ class DiffusionInpaintModel(InpaintModel):
padding_r = max(0, cropper_r - image_r) padding_r = max(0, cropper_r - image_r)
padding_b = max(0, cropper_b - image_b) padding_b = max(0, cropper_b - image_b)
expanded_image, mask_image = expand_image2( expanded_image, mask_image = expand_image(
cropped_image, cropped_image,
left=padding_l, left=padding_l,
top=padding_t, top=padding_t,
right=padding_r, right=padding_r,
bottom=padding_b, bottom=padding_b,
softness=config.sd_outpainting_softness,
space=config.sd_outpainting_space,
) )
# 最终扩大了的 image, BGR # 最终扩大了的 image, BGR
@ -381,15 +377,6 @@ class DiffusionInpaintModel(InpaintModel):
interpolation=cv2.INTER_CUBIC, interpolation=cv2.INTER_CUBIC,
) )
# blend result, copy from g_diffuser_bot
# mask_rgb = 1.0 - np_img_grey_to_rgb(mask / 255.0)
# inpaint_result = np.clip(
# inpaint_result * (1.0 - mask_rgb) + image * mask_rgb, 0.0, 255.0
# )
# original_pixel_indices = mask < 127
# inpaint_result[original_pixel_indices] = image[:, :, ::-1][
# original_pixel_indices
# ]
return inpaint_result return inpaint_result
def set_scheduler(self, config: InpaintRequest): def set_scheduler(self, config: InpaintRequest):
@ -412,7 +399,7 @@ class DiffusionInpaintModel(InpaintModel):
if config.sd_match_histograms: if config.sd_match_histograms:
result = self._match_histograms(result, image[:, :, ::-1], mask) result = self._match_histograms(result, image[:, :, ::-1], mask)
# if config.sd_mask_blur != 0: if config.use_extender and config.sd_mask_blur != 0:
# k = 2 * config.sd_mask_blur + 1 k = 2 * config.sd_mask_blur + 1
# mask = cv2.GaussianBlur(mask, (k, k), 0) mask = cv2.GaussianBlur(mask, (k, k), 0)
return result, image, mask return result, image, mask

View File

@ -1,174 +1,29 @@
# code copy from: https://github.com/parlance-zz/g-diffuser-bot
import cv2 import cv2
import numpy as np import numpy as np
def np_img_grey_to_rgb(data): def expand_image(cv2_img, top: int, right: int, bottom: int, left: int):
if data.ndim == 3:
return data
return np.expand_dims(data, 2) * np.ones((1, 1, 3))
def convolve(data1, data2): # fast convolution with fft
if data1.ndim != data2.ndim: # promote to rgb if mismatch
if data1.ndim < 3:
data1 = np_img_grey_to_rgb(data1)
if data2.ndim < 3:
data2 = np_img_grey_to_rgb(data2)
return ifft2(fft2(data1) * fft2(data2))
def fft2(data):
if data.ndim > 2: # multiple channels
out_fft = np.zeros(
(data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128
)
for c in range(data.shape[2]):
c_data = data[:, :, c]
out_fft[:, :, c] = np.fft.fft2(np.fft.fftshift(c_data), norm="ortho")
out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c])
else: # single channel
out_fft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
out_fft[:, :] = np.fft.fft2(np.fft.fftshift(data), norm="ortho")
out_fft[:, :] = np.fft.ifftshift(out_fft[:, :])
return out_fft
def ifft2(data):
if data.ndim > 2: # multiple channels
out_ifft = np.zeros(
(data.shape[0], data.shape[1], data.shape[2]), dtype=np.complex128
)
for c in range(data.shape[2]):
c_data = data[:, :, c]
out_ifft[:, :, c] = np.fft.ifft2(np.fft.fftshift(c_data), norm="ortho")
out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c])
else: # single channel
out_ifft = np.zeros((data.shape[0], data.shape[1]), dtype=np.complex128)
out_ifft[:, :] = np.fft.ifft2(np.fft.fftshift(data), norm="ortho")
out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :])
return out_ifft
def get_gradient_kernel(width, height, std=3.14, mode="linear"):
window_scale_x = float(
width / min(width, height)
) # for non-square aspect ratios we still want a circular kernel
window_scale_y = float(height / min(width, height))
if mode == "gaussian":
x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
kx = np.exp(-x * x * std)
if window_scale_x != window_scale_y:
y = (np.arange(height) / height * 2.0 - 1.0) * window_scale_y
ky = np.exp(-y * y * std)
else:
y = x
ky = kx
return np.outer(kx, ky)
elif mode == "linear":
x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x
if window_scale_x != window_scale_y:
y = (np.arange(height) / height * 2.0 - 1.0) * window_scale_y
else:
y = x
return np.clip(1.0 - np.sqrt(np.add.outer(x * x, y * y)) * std / 3.14, 0.0, 1.0)
else:
raise Exception("Error: Unknown mode in get_gradient_kernel: {0}".format(mode))
def image_blur(data, std=3.14, mode="linear"):
width = data.shape[0]
height = data.shape[1]
kernel = get_gradient_kernel(width, height, std, mode=mode)
return np.real(convolve(data, kernel / np.sqrt(np.sum(kernel * kernel))))
def soften_mask(mask_img, softness, space):
if softness == 0:
return mask_img
softness = min(softness, 1.0)
space = np.clip(space, 0.0, 1.0)
original_max_opacity = np.max(mask_img)
out_mask = mask_img <= 0.0
blurred_mask = image_blur(mask_img, 3.5 / softness, mode="linear")
blurred_mask = np.maximum(blurred_mask - np.max(blurred_mask[out_mask]), 0.0)
mask_img *= blurred_mask # preserve partial opacity in original input mask
mask_img /= np.max(mask_img) # renormalize
mask_img = np.clip(mask_img - space, 0.0, 1.0) # make space
mask_img /= np.max(mask_img) # and renormalize again
mask_img *= original_max_opacity # restore original max opacity
return mask_img
def expand_image(
cv2_img, top: int, right: int, bottom: int, left: int, softness: float, space: float
):
assert cv2_img.shape[2] == 3 assert cv2_img.shape[2] == 3
origin_h, origin_w = cv2_img.shape[:2] origin_h, origin_w = cv2_img.shape[:2]
new_width = cv2_img.shape[1] + left + right
new_height = cv2_img.shape[0] + top + bottom
# TODO: which is better? # TODO: which is better?
# new_img = np.random.randint(0, 255, (new_height, new_width, 3), np.uint8) # new_img = np.ones((new_height, new_width, 3), np.uint8) * 255
new_img = cv2.copyMakeBorder(
cv2_img, top, bottom, left, right, cv2.BORDER_REPLICATE
)
mask_img = np.zeros((new_height, new_width), np.uint8)
mask_img[top: top + cv2_img.shape[0], left: left + cv2_img.shape[1]] = 255
if softness > 0.0:
mask_img = soften_mask(mask_img / 255.0, softness / 100.0, space / 100.0)
mask_img = (np.clip(mask_img, 0.0, 1.0) * 255.0).astype(np.uint8)
mask_image = 255.0 - mask_img # extract mask from alpha channel and invert
rgb_init_image = (
0.0 + new_img[:, :, 0:3]
) # strip mask from init_img leaving only rgb channels
hard_mask = np.zeros_like(cv2_img[:, :, 0])
if top != 0:
hard_mask[0: origin_h // 2, :] = 255
if bottom != 0:
hard_mask[origin_h // 2:, :] = 255
if left != 0:
hard_mask[:, 0: origin_w // 2] = 255
if right != 0:
hard_mask[:, origin_w // 2:] = 255
hard_mask = cv2.copyMakeBorder(
hard_mask, top, bottom, left, right, cv2.BORDER_CONSTANT, value=255
)
mask_image = np.where(hard_mask > 0, mask_image, 0)
return rgb_init_image.astype(np.uint8), mask_image.astype(np.uint8)
def expand_image2(
cv2_img, top: int, right: int, bottom: int, left: int, softness: float, space: float
):
assert cv2_img.shape[2] == 3
origin_h, origin_w = cv2_img.shape[:2]
new_width = cv2_img.shape[1] + left + right
new_height = cv2_img.shape[0] + top + bottom
# TODO: which is better?
# new_img = np.random.randint(0, 255, (new_height, new_width, 3), np.uint8)
new_img = cv2.copyMakeBorder( new_img = cv2.copyMakeBorder(
cv2_img, top, bottom, left, right, cv2.BORDER_REPLICATE cv2_img, top, bottom, left, right, cv2.BORDER_REPLICATE
) )
inner_padding_left = 13 if left > 0 else 0 inner_padding_left = 0 if left > 0 else 0
inner_padding_right = 13 if right > 0 else 0 inner_padding_right = 0 if right > 0 else 0
inner_padding_top = 13 if top > 0 else 0 inner_padding_top = 0 if top > 0 else 0
inner_padding_bottom = 13 if bottom > 0 else 0 inner_padding_bottom = 0 if bottom > 0 else 0
mask_image = np.zeros( mask_image = np.zeros(
( (
origin_h - inner_padding_top - inner_padding_bottom origin_h - inner_padding_top - inner_padding_bottom,
, origin_w - inner_padding_left - inner_padding_right origin_w - inner_padding_left - inner_padding_right,
), ),
np.uint8) np.uint8,
)
mask_image = cv2.copyMakeBorder( mask_image = cv2.copyMakeBorder(
mask_image, mask_image,
top + inner_padding_top, top + inner_padding_top,
@ -176,11 +31,11 @@ def expand_image2(
left + inner_padding_left, left + inner_padding_left,
right + inner_padding_right, right + inner_padding_right,
cv2.BORDER_CONSTANT, cv2.BORDER_CONSTANT,
value=255 value=255,
) )
# k = 2*int(min(origin_h, origin_w) // 6)+1 # k = 2*int(min(origin_h, origin_w) // 6)+1
k = 7 # k = 7
mask_image = cv2.GaussianBlur(mask_image, (k, k), 0) # mask_image = cv2.GaussianBlur(mask_image, (k, k), 0)
return new_img, mask_image return new_img, mask_image
@ -190,7 +45,7 @@ if __name__ == "__main__":
current_dir = Path(__file__).parent.absolute().resolve() current_dir = Path(__file__).parent.absolute().resolve()
image_path = "/Users/cwq/code/github/IOPaint/iopaint/tests/bunny.jpeg" image_path = "/Users/cwq/code/github/IOPaint/iopaint/tests/bunny.jpeg"
init_image = cv2.imread(str(image_path)) init_image = cv2.imread(str(image_path))
init_image, mask_image = expand_image2( init_image, mask_image = expand_image(
init_image, init_image,
top=0, top=0,
right=0, right=0,

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,141 @@
import PIL.Image
import cv2
import torch
from loguru import logger
from transformers import CLIPTextModel, CLIPTokenizer
import numpy as np
from ..base import DiffusionInpaintModel
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
from ..utils import (
get_torch_dtype,
enable_low_mem,
is_local_files_only,
handle_from_pretrained_exceptions,
)
from .powerpaint_tokenizer import task_to_prompt
from iopaint.schema import InpaintRequest
from .v2.BrushNet_CA import BrushNetModel
from .v2.unet_2d_condition import UNet2DConditionModel
class PowerPaintV2(DiffusionInpaintModel):
pad_mod = 8
min_size = 512
lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
hf_model_id = "Sanster/PowerPaint_v2"
def init_model(self, device: torch.device, **kwargs):
from .v2.pipeline_PowerPaint_Brushnet_CA import (
StableDiffusionPowerPaintBrushNetPipeline,
)
from .powerpaint_tokenizer import PowerPaintTokenizer
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
model_kwargs = {"local_files_only": is_local_files_only(**kwargs)}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(
dict(
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
)
text_encoder_brushnet = CLIPTextModel.from_pretrained(
self.hf_model_id,
subfolder="text_encoder_brushnet",
variant="fp16",
torch_dtype=torch_dtype,
local_files_only=model_kwargs["local_files_only"],
)
unet = handle_from_pretrained_exceptions(
UNet2DConditionModel.from_pretrained,
pretrained_model_name_or_path=self.model_id_or_path,
subfolder="unet",
variant="fp16",
torch_dtype=torch_dtype,
local_files_only=model_kwargs["local_files_only"],
)
brushnet = BrushNetModel.from_pretrained(
self.hf_model_id,
subfolder="PowerPaint_Brushnet",
variant="fp16",
torch_dtype=torch_dtype,
local_files_only=model_kwargs["local_files_only"],
)
pipe = handle_from_pretrained_exceptions(
StableDiffusionPowerPaintBrushNetPipeline.from_pretrained,
pretrained_model_name_or_path=self.model_id_or_path,
torch_dtype=torch_dtype,
unet=unet,
brushnet=brushnet,
text_encoder_brushnet=text_encoder_brushnet,
variant="fp16",
**model_kwargs,
)
pipe.tokenizer = PowerPaintTokenizer(
CLIPTokenizer.from_pretrained(self.hf_model_id, subfolder="tokenizer")
)
self.model = pipe
enable_low_mem(self.model, kwargs.get("low_mem", False))
if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
self.model = self.model.to(device)
if kwargs["sd_cpu_textencoder"]:
logger.info("Run Stable Diffusion TextEncoder on CPU")
self.model.text_encoder = CPUTextEncoderWrapper(
self.model.text_encoder, torch_dtype
)
self.callback = kwargs.pop("callback", None)
def forward(self, image, mask, config: InpaintRequest):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
return: BGR IMAGE
"""
self.set_scheduler(config)
image = image * (1 - mask / 255.0)
img_h, img_w = image.shape[:2]
image = PIL.Image.fromarray(image.astype(np.uint8))
mask = PIL.Image.fromarray(mask[:, :, -1], mode="L").convert("RGB")
promptA, promptB, negative_promptA, negative_promptB = task_to_prompt(
config.powerpaint_task
)
output = self.model(
image=image,
mask=mask,
promptA=promptA,
promptB=promptB,
promptU=config.prompt,
tradoff=config.fitting_degree,
tradoff_nag=config.fitting_degree,
negative_promptA=negative_promptA,
negative_promptB=negative_promptB,
negative_promptU=config.negative_prompt,
num_inference_steps=config.sd_steps,
# strength=config.sd_strength,
brushnet_conditioning_scale=1.0,
guidance_scale=config.sd_guidance_scale,
output_type="np",
callback=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),
callback_steps=1,
).images[0]
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output

View File

@ -1,8 +1,6 @@
import torch
import torch.nn as nn
import copy import copy
import random import random
from typing import Any, List, Optional, Union from typing import Any, List, Union
from transformers import CLIPTokenizer from transformers import CLIPTokenizer
from iopaint.schema import PowerPaintTask from iopaint.schema import PowerPaintTask
@ -14,6 +12,11 @@ def add_task_to_prompt(prompt, negative_prompt, task: PowerPaintTask):
promptB = prompt + " P_ctxt" promptB = prompt + " P_ctxt"
negative_promptA = negative_prompt + " P_obj" negative_promptA = negative_prompt + " P_obj"
negative_promptB = negative_prompt + " P_obj" negative_promptB = negative_prompt + " P_obj"
elif task == PowerPaintTask.context_aware:
promptA = prompt + " P_ctxt"
promptB = prompt + " P_ctxt"
negative_promptA = negative_prompt
negative_promptB = negative_prompt
elif task == PowerPaintTask.shape_guided: elif task == PowerPaintTask.shape_guided:
promptA = prompt + " P_shape" promptA = prompt + " P_shape"
promptB = prompt + " P_ctxt" promptB = prompt + " P_ctxt"
@ -33,6 +36,18 @@ def add_task_to_prompt(prompt, negative_prompt, task: PowerPaintTask):
return promptA, promptB, negative_promptA, negative_promptB return promptA, promptB, negative_promptA, negative_promptB
def task_to_prompt(task: PowerPaintTask):
promptA, promptB, negative_promptA, negative_promptB = add_task_to_prompt(
"", "", task
)
return (
promptA.strip(),
promptB.strip(),
negative_promptA.strip(),
negative_promptB.strip(),
)
class PowerPaintTokenizer: class PowerPaintTokenizer:
def __init__(self, tokenizer: CLIPTokenizer): def __init__(self, tokenizer: CLIPTokenizer):
self.wrapped = tokenizer self.wrapped = tokenizer
@ -237,304 +252,3 @@ class PowerPaintTokenizer:
return text return text
replaced_text = self.replace_text_with_placeholder_tokens(text) replaced_text = self.replace_text_with_placeholder_tokens(text)
return replaced_text return replaced_text
class EmbeddingLayerWithFixes(nn.Module):
"""The revised embedding layer to support external embeddings. This design
of this class is inspired by https://github.com/AUTOMATIC1111/stable-
diffusion-webui/blob/22bcc7be428c94e9408f589966c2040187245d81/modules/sd_hi
jack.py#L224 # noqa.
Args:
wrapped (nn.Emebdding): The embedding layer to be wrapped.
external_embeddings (Union[dict, List[dict]], optional): The external
embeddings added to this layer. Defaults to None.
"""
def __init__(
self,
wrapped: nn.Embedding,
external_embeddings: Optional[Union[dict, List[dict]]] = None,
):
super().__init__()
self.wrapped = wrapped
self.num_embeddings = wrapped.weight.shape[0]
self.external_embeddings = []
if external_embeddings:
self.add_embeddings(external_embeddings)
self.trainable_embeddings = nn.ParameterDict()
@property
def weight(self):
"""Get the weight of wrapped embedding layer."""
return self.wrapped.weight
def check_duplicate_names(self, embeddings: List[dict]):
"""Check whether duplicate names exist in list of 'external
embeddings'.
Args:
embeddings (List[dict]): A list of embedding to be check.
"""
names = [emb["name"] for emb in embeddings]
assert len(names) == len(set(names)), (
"Found duplicated names in 'external_embeddings'. Name list: " f"'{names}'"
)
def check_ids_overlap(self, embeddings):
"""Check whether overlap exist in token ids of 'external_embeddings'.
Args:
embeddings (List[dict]): A list of embedding to be check.
"""
ids_range = [[emb["start"], emb["end"], emb["name"]] for emb in embeddings]
ids_range.sort() # sort by 'start'
# check if 'end' has overlapping
for idx in range(len(ids_range) - 1):
name1, name2 = ids_range[idx][-1], ids_range[idx + 1][-1]
assert ids_range[idx][1] <= ids_range[idx + 1][0], (
f"Found ids overlapping between embeddings '{name1}' " f"and '{name2}'."
)
def add_embeddings(self, embeddings: Optional[Union[dict, List[dict]]]):
"""Add external embeddings to this layer.
Use case:
>>> 1. Add token to tokenizer and get the token id.
>>> tokenizer = TokenizerWrapper('openai/clip-vit-base-patch32')
>>> # 'how much' in kiswahili
>>> tokenizer.add_placeholder_tokens('ngapi', num_vec_per_token=4)
>>>
>>> 2. Add external embeddings to the model.
>>> new_embedding = {
>>> 'name': 'ngapi', # 'how much' in kiswahili
>>> 'embedding': torch.ones(1, 15) * 4,
>>> 'start': tokenizer.get_token_info('kwaheri')['start'],
>>> 'end': tokenizer.get_token_info('kwaheri')['end'],
>>> 'trainable': False # if True, will registry as a parameter
>>> }
>>> embedding_layer = nn.Embedding(10, 15)
>>> embedding_layer_wrapper = EmbeddingLayerWithFixes(embedding_layer)
>>> embedding_layer_wrapper.add_embeddings(new_embedding)
>>>
>>> 3. Forward tokenizer and embedding layer!
>>> input_text = ['hello, ngapi!', 'hello my friend, ngapi?']
>>> input_ids = tokenizer(
>>> input_text, padding='max_length', truncation=True,
>>> return_tensors='pt')['input_ids']
>>> out_feat = embedding_layer_wrapper(input_ids)
>>>
>>> 4. Let's validate the result!
>>> assert (out_feat[0, 3: 7] == 2.3).all()
>>> assert (out_feat[2, 5: 9] == 2.3).all()
Args:
embeddings (Union[dict, list[dict]]): The external embeddings to
be added. Each dict must contain the following 4 fields: 'name'
(the name of this embedding), 'embedding' (the embedding
tensor), 'start' (the start token id of this embedding), 'end'
(the end token id of this embedding). For example:
`{name: NAME, start: START, end: END, embedding: torch.Tensor}`
"""
if isinstance(embeddings, dict):
embeddings = [embeddings]
self.external_embeddings += embeddings
self.check_duplicate_names(self.external_embeddings)
self.check_ids_overlap(self.external_embeddings)
# set for trainable
added_trainable_emb_info = []
for embedding in embeddings:
trainable = embedding.get("trainable", False)
if trainable:
name = embedding["name"]
embedding["embedding"] = torch.nn.Parameter(embedding["embedding"])
self.trainable_embeddings[name] = embedding["embedding"]
added_trainable_emb_info.append(name)
added_emb_info = [emb["name"] for emb in embeddings]
added_emb_info = ", ".join(added_emb_info)
print(f"Successfully add external embeddings: {added_emb_info}.", "current")
if added_trainable_emb_info:
added_trainable_emb_info = ", ".join(added_trainable_emb_info)
print(
"Successfully add trainable external embeddings: "
f"{added_trainable_emb_info}",
"current",
)
def replace_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
"""Replace external input ids to 0.
Args:
input_ids (torch.Tensor): The input ids to be replaced.
Returns:
torch.Tensor: The replaced input ids.
"""
input_ids_fwd = input_ids.clone()
input_ids_fwd[input_ids_fwd >= self.num_embeddings] = 0
return input_ids_fwd
def replace_embeddings(
self, input_ids: torch.Tensor, embedding: torch.Tensor, external_embedding: dict
) -> torch.Tensor:
"""Replace external embedding to the embedding layer. Noted that, in
this function we use `torch.cat` to avoid inplace modification.
Args:
input_ids (torch.Tensor): The original token ids. Shape like
[LENGTH, ].
embedding (torch.Tensor): The embedding of token ids after
`replace_input_ids` function.
external_embedding (dict): The external embedding to be replaced.
Returns:
torch.Tensor: The replaced embedding.
"""
new_embedding = []
name = external_embedding["name"]
start = external_embedding["start"]
end = external_embedding["end"]
target_ids_to_replace = [i for i in range(start, end)]
ext_emb = external_embedding["embedding"]
# do not need to replace
if not (input_ids == start).any():
return embedding
# start replace
s_idx, e_idx = 0, 0
while e_idx < len(input_ids):
if input_ids[e_idx] == start:
if e_idx != 0:
# add embedding do not need to replace
new_embedding.append(embedding[s_idx:e_idx])
# check if the next embedding need to replace is valid
actually_ids_to_replace = [
int(i) for i in input_ids[e_idx : e_idx + end - start]
]
assert actually_ids_to_replace == target_ids_to_replace, (
f"Invalid 'input_ids' in position: {s_idx} to {e_idx}. "
f"Expect '{target_ids_to_replace}' for embedding "
f"'{name}' but found '{actually_ids_to_replace}'."
)
new_embedding.append(ext_emb)
s_idx = e_idx + end - start
e_idx = s_idx + 1
else:
e_idx += 1
if e_idx == len(input_ids):
new_embedding.append(embedding[s_idx:e_idx])
return torch.cat(new_embedding, dim=0)
def forward(
self, input_ids: torch.Tensor, external_embeddings: Optional[List[dict]] = None
):
"""The forward function.
Args:
input_ids (torch.Tensor): The token ids shape like [bz, LENGTH] or
[LENGTH, ].
external_embeddings (Optional[List[dict]]): The external
embeddings. If not passed, only `self.external_embeddings`
will be used. Defaults to None.
input_ids: shape like [bz, LENGTH] or [LENGTH].
"""
assert input_ids.ndim in [1, 2]
if input_ids.ndim == 1:
input_ids = input_ids.unsqueeze(0)
if external_embeddings is None and not self.external_embeddings:
return self.wrapped(input_ids)
input_ids_fwd = self.replace_input_ids(input_ids)
inputs_embeds = self.wrapped(input_ids_fwd)
vecs = []
if external_embeddings is None:
external_embeddings = []
elif isinstance(external_embeddings, dict):
external_embeddings = [external_embeddings]
embeddings = self.external_embeddings + external_embeddings
for input_id, embedding in zip(input_ids, inputs_embeds):
new_embedding = embedding
for external_embedding in embeddings:
new_embedding = self.replace_embeddings(
input_id, new_embedding, external_embedding
)
vecs.append(new_embedding)
return torch.stack(vecs)
def add_tokens(
tokenizer,
text_encoder,
placeholder_tokens: list,
initialize_tokens: list = None,
num_vectors_per_token: int = 1,
):
"""Add token for training.
# TODO: support add tokens as dict, then we can load pretrained tokens.
"""
if initialize_tokens is not None:
assert len(initialize_tokens) == len(
placeholder_tokens
), "placeholder_token should be the same length as initialize_token"
for ii in range(len(placeholder_tokens)):
tokenizer.add_placeholder_token(
placeholder_tokens[ii], num_vec_per_token=num_vectors_per_token
)
# text_encoder.set_embedding_layer()
embedding_layer = text_encoder.text_model.embeddings.token_embedding
text_encoder.text_model.embeddings.token_embedding = EmbeddingLayerWithFixes(
embedding_layer
)
embedding_layer = text_encoder.text_model.embeddings.token_embedding
assert embedding_layer is not None, (
"Do not support get embedding layer for current text encoder. "
"Please check your configuration."
)
initialize_embedding = []
if initialize_tokens is not None:
for ii in range(len(placeholder_tokens)):
init_id = tokenizer(initialize_tokens[ii]).input_ids[1]
temp_embedding = embedding_layer.weight[init_id]
initialize_embedding.append(
temp_embedding[None, ...].repeat(num_vectors_per_token, 1)
)
else:
for ii in range(len(placeholder_tokens)):
init_id = tokenizer("a").input_ids[1]
temp_embedding = embedding_layer.weight[init_id]
len_emb = temp_embedding.shape[0]
init_weight = (torch.rand(num_vectors_per_token, len_emb) - 0.5) / 2.0
initialize_embedding.append(init_weight)
# initialize_embedding = torch.cat(initialize_embedding,dim=0)
token_info_all = []
for ii in range(len(placeholder_tokens)):
token_info = tokenizer.get_token_info(placeholder_tokens[ii])
token_info["embedding"] = initialize_embedding[ii]
token_info["trainable"] = True
token_info_all.append(token_info)
embedding_layer.add_embeddings(token_info_all)

View File

@ -0,0 +1,933 @@
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, \
TimestepEmbedding, Timesteps
from diffusers.models.modeling_utils import ModelMixin
from .unet_2d_blocks import (
CrossAttnDownBlock2D,
DownBlock2D,
get_down_block,
get_mid_block,
get_up_block
)
from .unet_2d_condition import UNet2DConditionModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class BrushNetOutput(BaseOutput):
"""
The output of [`BrushNetModel`].
Args:
up_block_res_samples (`tuple[torch.Tensor]`):
A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
used to condition the original UNet's upsampling activations.
down_block_res_samples (`tuple[torch.Tensor]`):
A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
used to condition the original UNet's downsampling activations.
mid_down_block_re_sample (`torch.Tensor`):
The activation of the midde block (the lowest sample resolution). Each tensor should be of shape
`(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
Output can be used to condition the original UNet's middle block activation.
"""
up_block_res_samples: Tuple[torch.Tensor]
down_block_res_samples: Tuple[torch.Tensor]
mid_block_res_sample: torch.Tensor
class BrushNetModel(ModelMixin, ConfigMixin):
"""
A BrushNet model.
Args:
in_channels (`int`, defaults to 4):
The number of channels in the input sample.
flip_sin_to_cos (`bool`, defaults to `True`):
Whether to flip the sin to cos in the time embedding.
freq_shift (`int`, defaults to 0):
The frequency shift to apply to the time embedding.
down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
`UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
The tuple of upsample blocks to use.
only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
The tuple of output channels for each block.
layers_per_block (`int`, defaults to 2):
The number of layers per block.
downsample_padding (`int`, defaults to 1):
The padding to use for the downsampling convolution.
mid_block_scale_factor (`float`, defaults to 1):
The scale factor to use for the mid block.
act_fn (`str`, defaults to "silu"):
The activation function to use.
norm_num_groups (`int`, *optional*, defaults to 32):
The number of groups to use for the normalization. If None, normalization and activation layers is skipped
in post-processing.
norm_eps (`float`, defaults to 1e-5):
The epsilon to use for the normalization.
cross_attention_dim (`int`, defaults to 1280):
The dimension of the cross attention features.
transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
[`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
[`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
encoder_hid_dim (`int`, *optional*, defaults to None):
If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
dimension to `cross_attention_dim`.
encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
The dimension of the attention heads.
use_linear_projection (`bool`, defaults to `False`):
class_embed_type (`str`, *optional*, defaults to `None`):
The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
`"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
addition_embed_type (`str`, *optional*, defaults to `None`):
Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
"text". "text" will use the `TextTimeEmbedding` layer.
num_class_embeds (`int`, *optional*, defaults to 0):
Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
class conditioning with `class_embed_type` equal to `None`.
upcast_attention (`bool`, defaults to `False`):
resnet_time_scale_shift (`str`, defaults to `"default"`):
Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
`class_embed_type="projection"`.
brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
The tuple of output channel for each block in the `conditioning_embedding` layer.
global_pool_conditions (`bool`, defaults to `False`):
TODO(Patrick) - unused parameter.
addition_embed_type_num_heads (`int`, defaults to 64):
The number of heads to use for the `TextTimeEmbedding` layer.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 4,
conditioning_channels: int = 5,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str, ...] = (
"UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
projection_class_embeddings_input_dim: Optional[int] = None,
brushnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
addition_embed_type_num_heads: int = 64,
):
super().__init__()
# If `num_attention_heads` is not defined (which is the case for most models)
# it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
# The reason for this behavior is to correct for incorrectly named variables that were introduced
# when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
# Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
# which is why we correct for the naming here.
num_attention_heads = num_attention_heads or attention_head_dim
# Check inputs
if len(down_block_types) != len(up_block_types):
raise ValueError(
f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
)
if len(block_out_channels) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
)
if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
)
if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
raise ValueError(
f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
)
if isinstance(transformer_layers_per_block, int):
transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
# input
conv_in_kernel = 3
conv_in_padding = (conv_in_kernel - 1) // 2
self.conv_in_condition = nn.Conv2d(
in_channels + conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel,
padding=conv_in_padding
)
# time
time_embed_dim = block_out_channels[0] * 4
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(
timestep_input_dim,
time_embed_dim,
act_fn=act_fn,
)
if encoder_hid_dim_type is None and encoder_hid_dim is not None:
encoder_hid_dim_type = "text_proj"
self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
if encoder_hid_dim is None and encoder_hid_dim_type is not None:
raise ValueError(
f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
)
if encoder_hid_dim_type == "text_proj":
self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
elif encoder_hid_dim_type == "text_image_proj":
# image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
self.encoder_hid_proj = TextImageProjection(
text_embed_dim=encoder_hid_dim,
image_embed_dim=cross_attention_dim,
cross_attention_dim=cross_attention_dim,
)
elif encoder_hid_dim_type is not None:
raise ValueError(
f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
)
else:
self.encoder_hid_proj = None
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
elif class_embed_type == "projection":
if projection_class_embeddings_input_dim is None:
raise ValueError(
"`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
)
# The projection `class_embed_type` is the same as the timestep `class_embed_type` except
# 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
# 2. it projects from an arbitrary input dimension.
#
# Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
# When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
# As a result, `TimestepEmbedding` can be passed arbitrary vectors.
self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
else:
self.class_embedding = None
if addition_embed_type == "text":
if encoder_hid_dim is not None:
text_time_embedding_from_dim = encoder_hid_dim
else:
text_time_embedding_from_dim = cross_attention_dim
self.add_embedding = TextTimeEmbedding(
text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
)
elif addition_embed_type == "text_image":
# text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
# they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
# case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
self.add_embedding = TextImageTimeEmbedding(
text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
)
elif addition_embed_type == "text_time":
self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
elif addition_embed_type is not None:
raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
self.down_blocks = nn.ModuleList([])
self.brushnet_down_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
if isinstance(num_attention_heads, int):
num_attention_heads = (num_attention_heads,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
brushnet_block = zero_module(brushnet_block)
self.brushnet_down_blocks.append(brushnet_block)
for i, down_block_type in enumerate(down_block_types):
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
transformer_layers_per_block=transformer_layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[i],
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
downsample_padding=downsample_padding,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.down_blocks.append(down_block)
for _ in range(layers_per_block):
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
brushnet_block = zero_module(brushnet_block)
self.brushnet_down_blocks.append(brushnet_block)
if not is_final_block:
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
brushnet_block = zero_module(brushnet_block)
self.brushnet_down_blocks.append(brushnet_block)
# mid
mid_block_channel = block_out_channels[-1]
brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
brushnet_block = zero_module(brushnet_block)
self.brushnet_mid_block = brushnet_block
self.mid_block = get_mid_block(
mid_block_type,
transformer_layers_per_block=transformer_layers_per_block[-1],
in_channels=mid_block_channel,
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
num_attention_heads=num_attention_heads[-1],
resnet_groups=norm_num_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
# count how many layers upsample the images
self.num_upsamplers = 0
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_num_attention_heads = list(reversed(num_attention_heads))
reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
self.up_blocks = nn.ModuleList([])
self.brushnet_up_blocks = nn.ModuleList([])
for i, up_block_type in enumerate(up_block_types):
is_final_block = i == len(block_out_channels) - 1
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
# add upsample block for all BUT final layer
if not is_final_block:
add_upsample = True
self.num_upsamplers += 1
else:
add_upsample = False
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
transformer_layers_per_block=reversed_transformer_layers_per_block[i],
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resolution_idx=i,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
num_attention_heads=reversed_num_attention_heads[i],
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
for _ in range(layers_per_block + 1):
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
brushnet_block = zero_module(brushnet_block)
self.brushnet_up_blocks.append(brushnet_block)
if not is_final_block:
brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
brushnet_block = zero_module(brushnet_block)
self.brushnet_up_blocks.append(brushnet_block)
@classmethod
def from_unet(
cls,
unet: UNet2DConditionModel,
brushnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
load_weights_from_unet: bool = True,
conditioning_channels: int = 5,
):
r"""
Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
Parameters:
unet (`UNet2DConditionModel`):
The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
where applicable.
"""
transformer_layers_per_block = (
unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
)
encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
addition_time_embed_dim = (
unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
)
brushnet = cls(
in_channels=unet.config.in_channels,
conditioning_channels=conditioning_channels,
flip_sin_to_cos=unet.config.flip_sin_to_cos,
freq_shift=unet.config.freq_shift,
# down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
down_block_types=["CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D", ],
# mid_block_type='MidBlock2D',
mid_block_type="UNetMidBlock2DCrossAttn",
# up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
only_cross_attention=unet.config.only_cross_attention,
block_out_channels=unet.config.block_out_channels,
layers_per_block=unet.config.layers_per_block,
downsample_padding=unet.config.downsample_padding,
mid_block_scale_factor=unet.config.mid_block_scale_factor,
act_fn=unet.config.act_fn,
norm_num_groups=unet.config.norm_num_groups,
norm_eps=unet.config.norm_eps,
cross_attention_dim=unet.config.cross_attention_dim,
transformer_layers_per_block=transformer_layers_per_block,
encoder_hid_dim=encoder_hid_dim,
encoder_hid_dim_type=encoder_hid_dim_type,
attention_head_dim=unet.config.attention_head_dim,
num_attention_heads=unet.config.num_attention_heads,
use_linear_projection=unet.config.use_linear_projection,
class_embed_type=unet.config.class_embed_type,
addition_embed_type=addition_embed_type,
addition_time_embed_dim=addition_time_embed_dim,
num_class_embeds=unet.config.num_class_embeds,
upcast_attention=unet.config.upcast_attention,
resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
)
if load_weights_from_unet:
conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
brushnet.conv_in_condition.bias = unet.conv_in.bias
brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
if brushnet.class_embedding:
brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
return brushnet.to(unet.dtype)
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
r"""
Enable sliced attention computation.
When this option is enabled, the attention module splits the input tensor in slices to compute attention in
several steps. This is useful for saving some memory in exchange for a small decrease in speed.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
`"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_sliceable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_sliceable_dims(module)
num_sliceable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_sliceable_layers * [1]
slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
brushnet_cond: torch.FloatTensor,
conditioning_scale: float = 1.0,
class_labels: Optional[torch.Tensor] = None,
timestep_cond: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guess_mode: bool = False,
return_dict: bool = True,
) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
"""
The [`BrushNetModel`] forward method.
Args:
sample (`torch.FloatTensor`):
The noisy input tensor.
timestep (`Union[torch.Tensor, float, int]`):
The number of timesteps to denoise an input.
encoder_hidden_states (`torch.Tensor`):
The encoder hidden states.
brushnet_cond (`torch.FloatTensor`):
The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
conditioning_scale (`float`, defaults to `1.0`):
The scale factor for BrushNet outputs.
class_labels (`torch.Tensor`, *optional*, defaults to `None`):
Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
embeddings.
attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
added_cond_kwargs (`dict`):
Additional conditions for the Stable Diffusion XL UNet.
cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
guess_mode (`bool`, defaults to `False`):
In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
Returns:
[`~models.brushnet.BrushNetOutput`] **or** `tuple`:
If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
# check channel order
channel_order = self.config.brushnet_conditioning_channel_order
if channel_order == "rgb":
# in rgb order by default
...
elif channel_order == "bgr":
brushnet_cond = torch.flip(brushnet_cond, dims=[1])
else:
raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# 1. time
timesteps = timestep
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=sample.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
aug_emb = None
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
if self.config.addition_embed_type is not None:
if self.config.addition_embed_type == "text":
aug_emb = self.add_embedding(encoder_hidden_states)
elif self.config.addition_embed_type == "text_time":
if "text_embeds" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
)
text_embeds = added_cond_kwargs.get("text_embeds")
if "time_ids" not in added_cond_kwargs:
raise ValueError(
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
)
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
emb = emb + aug_emb if aug_emb is not None else emb
# 2. pre-process
brushnet_cond = torch.concat([sample, brushnet_cond], 1)
sample = self.conv_in_condition(brushnet_cond)
# 3. down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
down_block_res_samples += res_samples
# 4. PaintingNet down blocks
brushnet_down_block_res_samples = ()
for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
down_block_res_sample = brushnet_down_block(down_block_res_sample)
brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
# 5. mid
if self.mid_block is not None:
if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
sample = self.mid_block(
sample,
emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
cross_attention_kwargs=cross_attention_kwargs,
)
else:
sample = self.mid_block(sample, emb)
# 6. BrushNet mid blocks
brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
# 7. up
up_block_res_samples = ()
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets):]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample, up_res_samples = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
upsample_size=upsample_size,
attention_mask=attention_mask,
return_res_samples=True
)
else:
sample, up_res_samples = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
return_res_samples=True
)
up_block_res_samples += up_res_samples
# 8. BrushNet up blocks
brushnet_up_block_res_samples = ()
for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
up_block_res_sample = brushnet_up_block(up_block_res_sample)
brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
# 6. scaling
if guess_mode and not self.config.global_pool_conditions:
scales = torch.logspace(-1, 0,
len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples),
device=sample.device) # 0.1 to 1.0
scales = scales * conditioning_scale
brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples,
scales[:len(
brushnet_down_block_res_samples)])]
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples,
scales[
len(brushnet_down_block_res_samples) + 1:])]
else:
brushnet_down_block_res_samples = [sample * conditioning_scale for sample in
brushnet_down_block_res_samples]
brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
if self.config.global_pool_conditions:
brushnet_down_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
]
brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
brushnet_up_block_res_samples = [
torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
]
if not return_dict:
return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
return BrushNetOutput(
down_block_res_samples=brushnet_down_block_res_samples,
mid_block_res_sample=brushnet_mid_block_res_sample,
up_block_res_samples=brushnet_up_block_res_samples
)
def zero_module(module):
for p in module.parameters():
nn.init.zeros_(p)
return module

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -8,6 +8,7 @@ from iopaint.download import scan_models
from iopaint.helper import switch_mps_device from iopaint.helper import switch_mps_device
from iopaint.model import models, ControlNet, SD, SDXL from iopaint.model import models, ControlNet, SD, SDXL
from iopaint.model.brushnet.brushnet_wrapper import BrushNetWrapper from iopaint.model.brushnet.brushnet_wrapper import BrushNetWrapper
from iopaint.model.power_paint.power_paint_v2 import PowerPaintV2
from iopaint.model.utils import torch_gc, is_local_files_only from iopaint.model.utils import torch_gc, is_local_files_only
from iopaint.schema import InpaintRequest, ModelInfo, ModelType from iopaint.schema import InpaintRequest, ModelInfo, ModelType
@ -23,9 +24,9 @@ class ModelManager:
self.enable_controlnet = kwargs.get("enable_controlnet", False) self.enable_controlnet = kwargs.get("enable_controlnet", False)
controlnet_method = kwargs.get("controlnet_method", None) controlnet_method = kwargs.get("controlnet_method", None)
if ( if (
controlnet_method is None controlnet_method is None
and name in self.available_models and name in self.available_models
and self.available_models[name].support_controlnet and self.available_models[name].support_controlnet
): ):
controlnet_method = self.available_models[name].controlnets[0] controlnet_method = self.available_models[name].controlnets[0]
self.controlnet_method = controlnet_method self.controlnet_method = controlnet_method
@ -33,6 +34,8 @@ class ModelManager:
self.enable_brushnet = kwargs.get("enable_brushnet", False) self.enable_brushnet = kwargs.get("enable_brushnet", False)
self.brushnet_method = kwargs.get("brushnet_method", None) self.brushnet_method = kwargs.get("brushnet_method", None)
self.enable_powerpaint_v2 = kwargs.get("enable_powerpaint_v2", False)
self.model = self.init_model(name, device, **kwargs) self.model = self.init_model(name, device, **kwargs)
@property @property
@ -62,6 +65,9 @@ class ModelManager:
if model_info.support_brushnet and self.enable_brushnet: if model_info.support_brushnet and self.enable_brushnet:
return BrushNetWrapper(device, **kwargs) return BrushNetWrapper(device, **kwargs)
if model_info.support_powerpaint_v2 and self.enable_powerpaint_v2:
return PowerPaintV2(device, **kwargs)
if model_info.name in models: if model_info.name in models:
return models[name](device, **kwargs) return models[name](device, **kwargs)
@ -91,10 +97,12 @@ class ModelManager:
Returns: Returns:
BGR image BGR image
""" """
if not config.enable_brushnet: if config.enable_controlnet:
self.switch_controlnet_method(config) self.switch_controlnet_method(config)
if not config.enable_controlnet: if config.enable_brushnet:
self.switch_brushnet_method(config) self.switch_brushnet_method(config)
self.enable_disable_powerpaint_v2(config)
self.enable_disable_freeu(config) self.enable_disable_freeu(config)
self.enable_disable_lcm_lora(config) self.enable_disable_lcm_lora(config)
return self.model(image, mask, config).astype(np.uint8) return self.model(image, mask, config).astype(np.uint8)
@ -113,9 +121,9 @@ class ModelManager:
self.name = new_name self.name = new_name
if ( if (
self.available_models[new_name].support_controlnet self.available_models[new_name].support_controlnet
and self.controlnet_method and self.controlnet_method
not in self.available_models[new_name].controlnets not in self.available_models[new_name].controlnets
): ):
self.controlnet_method = self.available_models[new_name].controlnets[0] self.controlnet_method = self.available_models[new_name].controlnets[0]
try: try:
@ -140,9 +148,9 @@ class ModelManager:
return return
if ( if (
self.enable_brushnet self.enable_brushnet
and config.brushnet_method and config.brushnet_method
and self.brushnet_method != config.brushnet_method and self.brushnet_method != config.brushnet_method
): ):
old_brushnet_method = self.brushnet_method old_brushnet_method = self.brushnet_method
self.brushnet_method = config.brushnet_method self.brushnet_method = config.brushnet_method
@ -180,9 +188,9 @@ class ModelManager:
return return
if ( if (
self.enable_controlnet self.enable_controlnet
and config.controlnet_method and config.controlnet_method
and self.controlnet_method != config.controlnet_method and self.controlnet_method != config.controlnet_method
): ):
old_controlnet_method = self.controlnet_method old_controlnet_method = self.controlnet_method
self.controlnet_method = config.controlnet_method self.controlnet_method = config.controlnet_method
@ -213,6 +221,25 @@ class ModelManager:
else: else:
logger.info(f"Enable controlnet: {config.controlnet_method}") logger.info(f"Enable controlnet: {config.controlnet_method}")
def enable_disable_powerpaint_v2(self, config: InpaintRequest):
if not self.available_models[self.name].support_powerpaint_v2:
return
if self.enable_powerpaint_v2 != config.enable_powerpaint_v2:
self.enable_powerpaint_v2 = config.enable_powerpaint_v2
pipe_components = {"vae": self.model.model.vae}
self.model = self.init_model(
self.name,
switch_mps_device(self.name, self.device),
pipe_components=pipe_components,
**self.kwargs,
)
if config.enable_powerpaint_v2:
logger.info("Enable PowerPaintV2")
else:
logger.info("Disable PowerPaintV2")
def enable_disable_freeu(self, config: InpaintRequest): def enable_disable_freeu(self, config: InpaintRequest):
if str(self.model.device) == "mps": if str(self.model.device) == "mps":
return return

View File

@ -119,6 +119,13 @@ class ModelInfo(BaseModel):
ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD,
] ]
@computed_field
@property
def support_powerpaint_v2(self) -> bool:
return self.model_type in [
ModelType.DIFFUSERS_SD,
]
@computed_field @computed_field
@property @property
def support_freeu(self) -> bool: def support_freeu(self) -> bool:
@ -225,8 +232,9 @@ class FREEUConfig(BaseModel):
b2: float = 1.4 b2: float = 1.4
class PowerPaintTask(str, Enum): class PowerPaintTask(Choices):
text_guided = "text-guided" text_guided = "text-guided"
context_aware = "context-aware"
shape_guided = "shape-guided" shape_guided = "shape-guided"
object_remove = "object-remove" object_remove = "object-remove"
outpainting = "outpainting" outpainting = "outpainting"
@ -387,12 +395,13 @@ class InpaintRequest(BaseModel):
# BrushNet # BrushNet
enable_brushnet: bool = Field(False, description="Enable brushnet") enable_brushnet: bool = Field(False, description="Enable brushnet")
brushnet_method: str = Field( brushnet_method: str = Field(SD_BRUSHNET_CHOICES[0], description="Brushnet method")
SD_BRUSHNET_CHOICES[0], description="Brushnet method" brushnet_conditioning_scale: float = Field(
1.0, description="brushnet conditioning scale", ge=0.0, le=1.0
) )
brushnet_conditioning_scale: float = Field(1.0, description="brushnet conditioning scale", ge=0.0, le=1.0)
# PowerPaint # PowerPaint
enable_powerpaint_v2: bool = Field(False, description="Enable PowerPaint v2")
powerpaint_task: PowerPaintTask = Field( powerpaint_task: PowerPaintTask = Field(
PowerPaintTask.text_guided, description="PowerPaint task" PowerPaintTask.text_guided, description="PowerPaint task"
) )
@ -403,8 +412,8 @@ class InpaintRequest(BaseModel):
le=1.0, le=1.0,
) )
@model_validator(mode='after') @model_validator(mode="after")
def validate_field(cls, values: 'InpaintRequest'): def validate_field(cls, values: "InpaintRequest"):
if values.sd_seed == -1: if values.sd_seed == -1:
values.sd_seed = random.randint(1, 99999999) values.sd_seed = random.randint(1, 99999999)
logger.info(f"Generate random seed: {values.sd_seed}") logger.info(f"Generate random seed: {values.sd_seed}")

View File

@ -10,7 +10,7 @@ import pytest
import torch import torch
from iopaint.model_manager import ModelManager from iopaint.model_manager import ModelManager
from iopaint.schema import HDStrategy, SDSampler, FREEUConfig from iopaint.schema import HDStrategy, SDSampler, FREEUConfig, PowerPaintTask
current_dir = Path(__file__).parent.absolute().resolve() current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / "result" save_dir = current_dir / "result"
@ -35,7 +35,7 @@ def test_runway_brushnet(device, sampler):
sd_freeu=True, sd_freeu=True,
sd_freeu_config=FREEUConfig(), sd_freeu_config=FREEUConfig(),
enable_brushnet=True, enable_brushnet=True,
brushnet_method=SD_BRUSHNET_CHOICES[0] brushnet_method=SD_BRUSHNET_CHOICES[0],
) )
cfg.sd_sampler = sampler cfg.sd_sampler = sampler
@ -49,38 +49,64 @@ def test_runway_brushnet(device, sampler):
@pytest.mark.parametrize("device", ["cuda", "mps"]) @pytest.mark.parametrize("device", ["cuda", "mps"])
@pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m_karras]) @pytest.mark.parametrize("sampler", [SDSampler.dpm_plus_plus_2m])
@pytest.mark.parametrize( def test_runway_powerpaint_v2(device, sampler):
"name",
[
"v1-5-pruned-emaonly.safetensors",
],
)
def test_brushnet_local_file_path(device, sampler, name):
sd_steps = check_device(device) sd_steps = check_device(device)
model = ModelManager( model = ModelManager(
name=name, name="runwayml/stable-diffusion-v1-5",
device=torch.device(device), device=torch.device(device),
disable_nsfw=True, disable_nsfw=True,
sd_cpu_textencoder=False, sd_cpu_textencoder=False,
cpu_offload=False,
) )
cfg = get_config(
strategy=HDStrategy.ORIGINAL,
prompt="face of a fox, sitting on a bench",
sd_steps=sd_steps,
sd_seed=1234,
enable_brushnet=True,
brushnet_method=SD_BRUSHNET_CHOICES[1]
)
cfg.sd_sampler = sampler
assert_equal( tasks = {
model, PowerPaintTask.text_guided: {
cfg, "prompt": "face of a fox, sitting on a bench",
f"brushnet_segmentation_mask_{device}.png", "scale": 7.5,
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", },
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", PowerPaintTask.context_aware: {
fx=1, "prompt": "face of a fox, sitting on a bench",
fy=1, "scale": 7.5,
) },
PowerPaintTask.shape_guided: {
"prompt": "face of a fox, sitting on a bench",
"scale": 7.5,
},
PowerPaintTask.object_remove: {
"prompt": "",
"scale": 12,
},
PowerPaintTask.outpainting: {
"prompt": "",
"scale": 7.5,
},
}
for task, data in tasks.items():
cfg = get_config(
strategy=HDStrategy.ORIGINAL,
prompt=data["prompt"],
negative_prompt="out of frame, lowres, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, disfigured, gross proportions, malformed limbs, watermark, signature",
sd_steps=sd_steps,
sd_guidance_scale=data["scale"],
enable_powerpaint_v2=True,
powerpaint_task=task,
sd_sampler=sampler,
sd_mask_blur=11,
sd_seed=42,
# sd_keep_unmasked_area=False
)
if task == PowerPaintTask.outpainting:
cfg.use_extender = True
cfg.extender_x = -128
cfg.extender_y = -128
cfg.extender_width = 768
cfg.extender_height = 768
assert_equal(
model,
cfg,
f"powerpaint_v2_{device}_{task}.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
)

View File

@ -3,18 +3,17 @@ import os
from iopaint.tests.utils import current_dir, check_device from iopaint.tests.utils import current_dir, check_device
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from pathlib import Path
import pytest import pytest
import torch import torch
from iopaint.model_manager import ModelManager from iopaint.model_manager import ModelManager
from iopaint.schema import HDStrategy, SDSampler from iopaint.schema import SDSampler
from iopaint.tests.test_model import get_config, assert_equal from iopaint.tests.test_model import get_config, assert_equal
@pytest.mark.parametrize("name", ["runwayml/stable-diffusion-inpainting"]) @pytest.mark.parametrize("name", ["runwayml/stable-diffusion-inpainting"])
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) @pytest.mark.parametrize("device", ["cuda", "mps"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"rect", "rect",
[ [
@ -23,7 +22,7 @@ from iopaint.tests.test_model import get_config, assert_equal
[128, 0, 512 - 128 + 100, 512], [128, 0, 512 - 128 + 100, 512],
[-100, 0, 512 - 128 + 100, 512], [-100, 0, 512 - 128 + 100, 512],
[0, 0, 512, 512 + 200], [0, 0, 512, 512 + 200],
[0, 0, 512 + 200, 512], [256, 0, 512 + 200, 512],
[-100, -100, 512 + 200, 512 + 200], [-100, -100, 512 + 200, 512 + 200],
], ],
) )
@ -58,7 +57,7 @@ def test_outpainting(name, device, rect):
@pytest.mark.parametrize("name", ["kandinsky-community/kandinsky-2-2-decoder-inpaint"]) @pytest.mark.parametrize("name", ["kandinsky-community/kandinsky-2-2-decoder-inpaint"])
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) @pytest.mark.parametrize("device", ["cuda", "mps"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"rect", "rect",
[ [
@ -99,7 +98,7 @@ def test_kandinsky_outpainting(name, device, rect):
@pytest.mark.parametrize("name", ["Sanster/PowerPaint-V1-stable-diffusion-inpainting"]) @pytest.mark.parametrize("name", ["Sanster/PowerPaint-V1-stable-diffusion-inpainting"])
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) @pytest.mark.parametrize("device", ["cuda", "mps"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"rect", "rect",
[ [
@ -114,7 +113,7 @@ def test_powerpaint_outpainting(name, device, rect):
device=torch.device(device), device=torch.device(device),
disable_nsfw=True, disable_nsfw=True,
sd_cpu_textencoder=False, sd_cpu_textencoder=False,
low_mem=True low_mem=True,
) )
cfg = get_config( cfg = get_config(
prompt="a dog sitting on a bench in the park", prompt="a dog sitting on a bench in the park",

View File

@ -3,9 +3,8 @@ import cv2
import pytest import pytest
import torch import torch
from iopaint.helper import encode_pil_to_base64
from iopaint.schema import LDMSampler, HDStrategy, InpaintRequest, SDSampler from iopaint.schema import LDMSampler, HDStrategy, InpaintRequest, SDSampler
from PIL import Image import numpy as np
current_dir = Path(__file__).parent.absolute().resolve() current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / "result" save_dir = current_dir / "result"
@ -32,6 +31,7 @@ def assert_equal(
): ):
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p) img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
print(f"Input image shape: {img.shape}") print(f"Input image shape: {img.shape}")
res = model(img, mask, config) res = model(img, mask, config)
ok = cv2.imwrite( ok = cv2.imwrite(
str(save_dir / gt_name), str(save_dir / gt_name),