2024-04-24 14:22:29 +02:00
import inspect
from typing import Any , Callable , Dict , List , Optional , Union
import numpy as np
import PIL . Image
import torch
import torch . nn . functional as F
2024-04-29 16:20:44 +02:00
from diffusers import StableDiffusionMixin , UNet2DConditionModel
from transformers import (
CLIPImageProcessor ,
CLIPTextModel ,
CLIPTokenizer ,
CLIPVisionModelWithProjection ,
)
2024-04-24 14:22:29 +02:00
from diffusers . image_processor import PipelineImageInput , VaeImageProcessor
2024-04-29 16:20:44 +02:00
from diffusers . loaders import (
FromSingleFileMixin ,
IPAdapterMixin ,
LoraLoaderMixin ,
TextualInversionLoaderMixin ,
)
2024-04-24 14:22:29 +02:00
from diffusers . models import AutoencoderKL , ImageProjection
from diffusers . models . lora import adjust_lora_scale_text_encoder
from diffusers . schedulers import KarrasDiffusionSchedulers
from diffusers . utils import (
USE_PEFT_BACKEND ,
deprecate ,
logging ,
replace_example_docstring ,
scale_lora_layers ,
unscale_lora_layers ,
)
2024-04-29 16:20:44 +02:00
from diffusers . utils . torch_utils import (
is_compiled_module ,
is_torch_version ,
randn_tensor ,
)
2024-04-24 14:22:29 +02:00
from diffusers . pipelines . pipeline_utils import DiffusionPipeline
2024-04-29 16:20:44 +02:00
from diffusers . pipelines . stable_diffusion . pipeline_output import (
StableDiffusionPipelineOutput ,
)
from diffusers . pipelines . stable_diffusion . safety_checker import (
StableDiffusionSafetyChecker ,
)
2024-04-24 14:22:29 +02:00
from . BrushNet_CA import BrushNetModel
# Module-level logger, obtained from the `logging` helper imported from `diffusers.utils`
# and keyed by this module's name.
logger = logging . get_logger ( __name__ ) # pylint: disable=invalid-name
# Usage example kept as a module-level string constant.
# NOTE(review): presumably injected into a pipeline docstring through `replace_example_docstring`
# (imported above); the call site is not visible in this part of the file -- confirm.
# NOTE(review): the example imports `StableDiffusionBrushNetPipeline` from `diffusers` and passes
# `paintingnet_conditioning_scale`; verify both against the pipeline class defined in this file
# and the signature of its `__call__`.
EXAMPLE_DOC_STRING = """
Examples :
` ` ` py
from diffusers import StableDiffusionBrushNetPipeline , BrushNetModel , UniPCMultistepScheduler
from diffusers . utils import load_image
import torch
import cv2
import numpy as np
from PIL import Image
base_model_path = " runwayml/stable-diffusion-v1-5 "
brushnet_path = " ckpt_path "
brushnet = BrushNetModel . from_pretrained ( brushnet_path , torch_dtype = torch . float16 )
pipe = StableDiffusionBrushNetPipeline . from_pretrained (
base_model_path , brushnet = brushnet , torch_dtype = torch . float16 , low_cpu_mem_usage = False
)
# speed up diffusion process with faster scheduler and memory optimization
pipe . scheduler = UniPCMultistepScheduler . from_config ( pipe . scheduler . config )
# remove following line if xformers is not installed or when using Torch 2.0.
# pipe.enable_xformers_memory_efficient_attention()
# memory optimization.
pipe . enable_model_cpu_offload ( )
image_path = " examples/brushnet/src/test_image.jpg "
mask_path = " examples/brushnet/src/test_mask.jpg "
caption = " A cake on the table. "
init_image = cv2 . imread ( image_path )
mask_image = 1. * ( cv2 . imread ( mask_path ) . sum ( - 1 ) > 255 ) [ : , : , np . newaxis ]
init_image = init_image * ( 1 - mask_image )
init_image = Image . fromarray ( init_image . astype ( np . uint8 ) ) . convert ( " RGB " )
mask_image = Image . fromarray ( mask_image . astype ( np . uint8 ) . repeat ( 3 , - 1 ) * 255 ) . convert ( " RGB " )
generator = torch . Generator ( " cuda " ) . manual_seed ( 1234 )
image = pipe (
caption ,
init_image ,
mask_image ,
num_inference_steps = 50 ,
generator = generator ,
paintingnet_conditioning_scale = 1.0
) . images [ 0 ]
image . save ( " output.png " )
` ` `
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call.
    Handles custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. Only used
            when `timesteps` is `None`; it is ignored (and recomputed) when custom `timesteps` are given.
        device (`str` or `torch.device`, *optional*):
            Forwarded to `scheduler.set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, the scheduler
            is configured from `num_inference_steps` instead.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the
        scheduler and the second element is the number of inference steps.

    Raises:
        ValueError: If custom `timesteps` are passed but `scheduler.set_timesteps` has no `timesteps`
            parameter.
    """
    if timesteps is not None:
        # The parameter name must match exactly; `parameters` is a mapping keyed by name.
        accepts_timesteps = "timesteps" in inspect.signature(scheduler.set_timesteps).parameters
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        # With a custom schedule the step count is whatever the scheduler ended up with.
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
class StableDiffusionPowerPaintBrushNetPipeline (
DiffusionPipeline ,
StableDiffusionMixin ,
TextualInversionLoaderMixin ,
LoraLoaderMixin ,
IPAdapterMixin ,
FromSingleFileMixin ,
) :
r """
Pipeline for text - to - image generation using Stable Diffusion with BrushNet guidance .
This model inherits from [ ` DiffusionPipeline ` ] . Check the superclass documentation for the generic methods
implemented for all pipelines ( downloading , saving , running on a particular device , etc . ) .
The pipeline also inherits the following loading methods :
- [ ` ~ loaders . TextualInversionLoaderMixin . load_textual_inversion ` ] for loading textual inversion embeddings
- [ ` ~ loaders . LoraLoaderMixin . load_lora_weights ` ] for loading LoRA weights
- [ ` ~ loaders . LoraLoaderMixin . save_lora_weights ` ] for saving LoRA weights
- [ ` ~ loaders . FromSingleFileMixin . from_single_file ` ] for loading ` . ckpt ` files
- [ ` ~ loaders . IPAdapterMixin . load_ip_adapter ` ] for loading IP Adapters
Args :
vae ( [ ` AutoencoderKL ` ] ) :
Variational Auto - Encoder ( VAE ) model to encode and decode images to and from latent representations .
text_encoder ( [ ` ~ transformers . CLIPTextModel ` ] ) :
Frozen text - encoder ( [ clip - vit - large - patch14 ] ( https : / / huggingface . co / openai / clip - vit - large - patch14 ) ) .
tokenizer ( [ ` ~ transformers . CLIPTokenizer ` ] ) :
A ` CLIPTokenizer ` to tokenize text .
unet ( [ ` UNet2DConditionModel ` ] ) :
A ` UNet2DConditionModel ` to denoise the encoded image latents .
brushnet ( [ ` BrushNetModel ` ] ` ) :
Provides additional conditioning to the ` unet ` during the denoising process .
scheduler ( [ ` SchedulerMixin ` ] ) :
A scheduler to be used in combination with ` unet ` to denoise the encoded image latents . Can be one of
[ ` DDIMScheduler ` ] , [ ` LMSDiscreteScheduler ` ] , or [ ` PNDMScheduler ` ] .
safety_checker ( [ ` StableDiffusionSafetyChecker ` ] ) :
Classification module that estimates whether generated images could be considered offensive or harmful .
Please refer to the [ model card ] ( https : / / huggingface . co / runwayml / stable - diffusion - v1 - 5 ) for more details
about a model ' s potential harms.
feature_extractor ( [ ` ~ transformers . CLIPImageProcessor ` ] ) :
A ` CLIPImageProcessor ` to extract features from generated images ; used as inputs to the ` safety_checker ` .
"""
model_cpu_offload_seq = " text_encoder->image_encoder->unet->vae "
_optional_components = [ " safety_checker " , " feature_extractor " , " image_encoder " ]
_exclude_from_cpu_offload = [ " safety_checker " ]
_callback_tensor_inputs = [ " latents " , " prompt_embeds " , " negative_prompt_embeds " ]
def __init__(
    self,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    text_encoder_brushnet: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    unet: UNet2DConditionModel,
    brushnet: BrushNetModel,
    scheduler: KarrasDiffusionSchedulers,
    safety_checker: StableDiffusionSafetyChecker,
    feature_extractor: CLIPImageProcessor,
    image_encoder: Optional[CLIPVisionModelWithProjection] = None,
    requires_safety_checker: bool = True,
):
    """
    Register the pipeline components and derive the image-processing helpers.

    Args:
        vae: Autoencoder; its `config.block_out_channels` length determines `vae_scale_factor`.
        text_encoder: Text encoder used by `encode_prompt`.
        text_encoder_brushnet: Text encoder used by `_encode_prompt` for the blended prompt pair.
        tokenizer: Tokenizer shared by both text encoders.
        unet: Denoising UNet.
        brushnet: BrushNet model registered alongside the UNet.
        scheduler: Scheduler registered with the pipeline.
        safety_checker: Optional safety checker; `None` disables it (a warning is logged when
            `requires_safety_checker` is still `True`).
        feature_extractor: Required whenever `safety_checker` is not `None`.
        image_encoder: Optional image encoder used by `encode_image`.
        requires_safety_checker: Stored in the pipeline config.

    Raises:
        ValueError: If a `safety_checker` is given without a `feature_extractor`.
    """
    super().__init__()

    if safety_checker is None and requires_safety_checker:
        logger.warning(
            f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
            " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
            " results in services or applications open to the public. Both the diffusers team and Hugging Face"
            " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
            " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
            " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
        )

    if safety_checker is not None and feature_extractor is None:
        # f-string so that the class name is actually interpolated into the message.
        raise ValueError(
            f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
            " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
        )

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        text_encoder_brushnet=text_encoder_brushnet,
        tokenizer=tokenizer,
        unet=unet,
        brushnet=brushnet,
        scheduler=scheduler,
        safety_checker=safety_checker,
        feature_extractor=feature_extractor,
        image_encoder=image_encoder,
    )
    # One factor of two per VAE block after the first.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
    )
    self.register_to_config(requires_safety_checker=requires_safety_checker)
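# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt (this variant takes a prompt pair, promptA/promptB, and blends their embeddings with the weights t and t_nag)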
def _encode_prompt(
    self,
    promptA,
    promptB,
    t,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_promptA=None,
    negative_promptB=None,
    t_nag=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
):
    r"""
    Encodes a pair of prompts with `text_encoder_brushnet` and blends the two embeddings.

    The positive embedding is `t * embed(promptA) + (1 - t) * embed(promptB)`; the negative embedding is
    built the same way from `negative_promptA` / `negative_promptB` with the weight `t_nag`.

    Args:
        promptA (`str` or `List[str]`, *optional*):
            first prompt to be encoded; also determines the batch size
        promptB (`str` or `List[str]`, *optional*):
            second prompt to be encoded
        t:
            blend weight applied to the `promptA` embedding (`1 - t` is applied to `promptB`)
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_promptA, negative_promptB (`str` or `List[str]`, *optional*):
            The negative prompt pair. When `negative_promptA` is `None`, empty strings are used for both.
            When `negative_promptA` is given, `negative_promptB` is tokenized as well and must be usable by
            the tokenizer.
        t_nag:
            blend weight for the negative pair; it is used in arithmetic whenever the negative embeddings
            are computed here, so it must be provided in that case
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. If not provided, text embeddings are generated from the prompt
            pair.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. If not provided (and guidance is enabled) they are
            generated from the negative prompt pair.
        lora_scale (`float`, *optional*):
            Stored on the pipeline as `_lora_scale` when given.

    Returns:
        `torch.Tensor`: the prompt embeddings repeated `num_images_per_prompt` times; with classifier free
        guidance the negative embeddings are concatenated in front of them along the batch dimension.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale

    # Prompt A drives batch size and the type checks against the negative prompt.
    prompt = promptA
    negative_prompt = negative_promptA

    if promptA is not None and isinstance(promptA, str):
        batch_size = 1
    elif promptA is not None and isinstance(promptA, list):
        batch_size = len(promptA)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary (both prompts, as for the negatives)
        if isinstance(self, TextualInversionLoaderMixin):
            promptA = self.maybe_convert_prompt(promptA, self.tokenizer)
            promptB = self.maybe_convert_prompt(promptB, self.tokenizer)

        text_inputsA = self.tokenizer(
            promptA,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_inputsB = self.tokenizer(
            promptB,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_idsA = text_inputsA.input_ids
        text_input_idsB = text_inputsB.input_ids
        untruncated_ids = self.tokenizer(
            promptA, padding="longest", return_tensors="pt"
        ).input_ids

        if untruncated_ids.shape[-1] >= text_input_idsA.shape[-1] and not torch.equal(
            text_input_idsA, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if (
            hasattr(self.text_encoder_brushnet.config, "use_attention_mask")
            and self.text_encoder_brushnet.config.use_attention_mask
        ):
            # Each prompt is encoded with its own mask; A's mask does not describe B's padding.
            attention_maskA = text_inputsA.attention_mask.to(device)
            attention_maskB = text_inputsB.attention_mask.to(device)
        else:
            attention_maskA = None
            attention_maskB = None

        prompt_embedsA = self.text_encoder_brushnet(
            text_input_idsA.to(device),
            attention_mask=attention_maskA,
        )[0]
        prompt_embedsB = self.text_encoder_brushnet(
            text_input_idsB.to(device),
            attention_mask=attention_maskB,
        )[0]
        prompt_embeds = prompt_embedsA * t + (1 - t) * prompt_embedsB

    if self.text_encoder_brushnet is not None:
        prompt_embeds_dtype = self.text_encoder_brushnet.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokensA: List[str]
        uncond_tokensB: List[str]
        if negative_prompt is None:
            uncond_tokensA = [""] * batch_size
            uncond_tokensB = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokensA = [negative_promptA]
            uncond_tokensB = [negative_promptB]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokensA = negative_promptA
            uncond_tokensB = negative_promptB

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokensA = self.maybe_convert_prompt(uncond_tokensA, self.tokenizer)
            uncond_tokensB = self.maybe_convert_prompt(uncond_tokensB, self.tokenizer)

        # Pad the negatives to the same sequence length as the positive embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_inputA = self.tokenizer(
            uncond_tokensA,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )
        uncond_inputB = self.tokenizer(
            uncond_tokensB,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if (
            hasattr(self.text_encoder_brushnet.config, "use_attention_mask")
            and self.text_encoder_brushnet.config.use_attention_mask
        ):
            uncond_attention_maskA = uncond_inputA.attention_mask.to(device)
            uncond_attention_maskB = uncond_inputB.attention_mask.to(device)
        else:
            uncond_attention_maskA = None
            uncond_attention_maskB = None

        negative_prompt_embedsA = self.text_encoder_brushnet(
            uncond_inputA.input_ids.to(device),
            attention_mask=uncond_attention_maskA,
        )
        negative_prompt_embedsB = self.text_encoder_brushnet(
            uncond_inputB.input_ids.to(device),
            attention_mask=uncond_attention_maskB,
        )
        negative_prompt_embeds = (
            negative_prompt_embedsA[0] * t_nag + (1 - t_nag) * negative_prompt_embedsB[0]
        )

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(
            batch_size * num_images_per_prompt, seq_len, -1
        )

        # For classifier free guidance, we need to do two forward passes.
        # Here we concatenate the unconditional and text embeddings into a single batch
        # to avoid doing two forward passes
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

    return prompt_embeds
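# Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt (this variant returns a single tensor with the negative embeddings concatenated in front of the prompt embeddings)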
def encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    lora_scale: Optional[float] = None,
    clip_skip: Optional[int] = None,
):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        do_classifier_free_guidance (`bool`):
            whether to use classifier free guidance or not
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Only used when classifier free guidance is enabled and no
            `negative_prompt_embeds` are given.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
            If not provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. If not provided, negative_prompt_embeds will be generated
            from `negative_prompt` input argument when classifier free guidance is enabled.
        lora_scale (`float`, *optional*):
            A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are
            loaded.
        clip_skip (`int`, *optional*):
            Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1
            means that the output of the pre-final layer will be used for computing the prompt embeddings.

    Returns:
        `torch.Tensor`: the negative embeddings concatenated in front of the prompt embeddings along the
        batch dimension. When no negative embeddings exist (guidance disabled and none supplied), only the
        prompt embeddings are returned.
    """
    # set lora scale so that monkey patched LoRA
    # function of text encoder can correctly access it
    if lora_scale is not None and isinstance(self, LoraLoaderMixin):
        self._lora_scale = lora_scale

        # dynamically adjust the LoRA scale
        if not USE_PEFT_BACKEND:
            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
        else:
            scale_lora_layers(self.text_encoder, lora_scale)

    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    if prompt_embeds is None:
        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(
            prompt, padding="longest", return_tensors="pt"
        ).input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
            text_input_ids, untruncated_ids
        ):
            removed_text = self.tokenizer.batch_decode(
                untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
            )
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
            )

        if (
            hasattr(self.text_encoder.config, "use_attention_mask")
            and self.text_encoder.config.use_attention_mask
        ):
            attention_mask = text_inputs.attention_mask.to(device)
        else:
            attention_mask = None

        if clip_skip is None:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device), attention_mask=attention_mask
            )
            prompt_embeds = prompt_embeds[0]
        else:
            prompt_embeds = self.text_encoder(
                text_input_ids.to(device),
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
            # Access the `hidden_states` first, that contains a tuple of
            # all the hidden states from the encoder layers. Then index into
            # the tuple to access the hidden states from the desired layer.
            prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
            # We also need to apply the final LayerNorm here to not mess with the
            # representations. The `last_hidden_states` that we typically use for
            # obtaining the final prompt representations passes through the LayerNorm
            # layer.
            prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)

    if self.text_encoder is not None:
        prompt_embeds_dtype = self.text_encoder.dtype
    elif self.unet is not None:
        prompt_embeds_dtype = self.unet.dtype
    else:
        prompt_embeds_dtype = prompt_embeds.dtype

    prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

    bs_embed, seq_len, _ = prompt_embeds.shape
    # duplicate text embeddings for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

    # get unconditional embeddings for classifier free guidance
    if do_classifier_free_guidance and negative_prompt_embeds is None:
        uncond_tokens: List[str]
        if negative_prompt is None:
            uncond_tokens = [""] * batch_size
        elif prompt is not None and type(prompt) is not type(negative_prompt):
            raise TypeError(
                f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                f" {type(prompt)}."
            )
        elif isinstance(negative_prompt, str):
            uncond_tokens = [negative_prompt]
        elif batch_size != len(negative_prompt):
            raise ValueError(
                f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                " the batch size of `prompt`."
            )
        else:
            uncond_tokens = negative_prompt

        # textual inversion: process multi-vector tokens if necessary
        if isinstance(self, TextualInversionLoaderMixin):
            uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

        # Pad the negatives to the same sequence length as the prompt embeddings.
        max_length = prompt_embeds.shape[1]
        uncond_input = self.tokenizer(
            uncond_tokens,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        )

        if (
            hasattr(self.text_encoder.config, "use_attention_mask")
            and self.text_encoder.config.use_attention_mask
        ):
            attention_mask = uncond_input.attention_mask.to(device)
        else:
            attention_mask = None

        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

    if do_classifier_free_guidance:
        # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
        seq_len = negative_prompt_embeds.shape[1]

        negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
        negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
        negative_prompt_embeds = negative_prompt_embeds.view(
            batch_size * num_images_per_prompt, seq_len, -1
        )

    if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
        # Retrieve the original scale by scaling back the LoRA layers
        unscale_lora_layers(self.text_encoder, lora_scale)

    if negative_prompt_embeds is None:
        # Guidance is off and no negatives were supplied: `torch.cat` cannot take `None`,
        # so there is nothing to prepend.
        return prompt_embeds
    return torch.cat([negative_prompt_embeds, prompt_embeds])
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
2024-04-29 16:20:44 +02:00
def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
    """
    Encode an image with `self.image_encoder` and build the matching unconditional encoding.

    Args:
        image: A `torch.Tensor`, or any input accepted by `self.feature_extractor` (non-tensor inputs are
            converted to pixel values first).
        device: Device the pixel values are moved to.
        num_images_per_prompt (`int`): Each encoding is repeated this many times along dim 0.
        output_hidden_states (`bool`, *optional*): When truthy, the second-to-last hidden state is
            returned instead of `image_embeds`.

    Returns:
        A `(conditional, unconditional)` pair. With hidden states, the unconditional part is the encoding
        of an all-zero image; otherwise it is an all-zero tensor shaped like the image embeddings.
    """
    # Run the encoder in its own parameter dtype.
    dtype = next(self.image_encoder.parameters()).dtype

    if not isinstance(image, torch.Tensor):
        image = self.feature_extractor(image, return_tensors="pt").pixel_values

    image = image.to(device=device, dtype=dtype)
    if output_hidden_states:
        image_enc_hidden_states = self.image_encoder(
            image, output_hidden_states=True
        ).hidden_states[-2]
        image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(
            num_images_per_prompt, dim=0
        )
        uncond_image_enc_hidden_states = self.image_encoder(
            torch.zeros_like(image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
            num_images_per_prompt, dim=0
        )
        return image_enc_hidden_states, uncond_image_enc_hidden_states
    else:
        image_embeds = self.image_encoder(image).image_embeds
        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
        uncond_image_embeds = torch.zeros_like(image_embeds)

        return image_embeds, uncond_image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
def prepare_ip_adapter_image_embeds(
    self,
    ip_adapter_image,
    ip_adapter_image_embeds,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
):
    """
    Build the list of IP-Adapter image embeddings, one entry per adapter.

    Args:
        ip_adapter_image: Image or list of images; only used when `ip_adapter_image_embeds` is `None`.
            A single image is wrapped into a one-element list.
        ip_adapter_image_embeds: Optional list of pre-computed embeddings. With classifier free guidance
            each tensor is split in two halves along dim 0 (negative half first).
        device: Device for freshly encoded embeddings.
        num_images_per_prompt (`int`): Number of copies made of every embedding.
        do_classifier_free_guidance (`bool`): When set, negative embeddings are concatenated in front of
            the positive ones.

    Returns:
        `list` of tensors, one per adapter.

    Raises:
        ValueError: If the number of images differs from the number of image projection layers on
            `self.unet.encoder_hid_proj`.
    """
    if ip_adapter_image_embeds is None:
        if not isinstance(ip_adapter_image, list):
            ip_adapter_image = [ip_adapter_image]

        if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
            raise ValueError(
                f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
            )

        image_embeds = []
        for single_ip_adapter_image, image_proj_layer in zip(
            ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
        ):
            # Only plain `ImageProjection` layers consume pooled embeddings; others get hidden states.
            output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
            single_image_embeds, single_negative_image_embeds = self.encode_image(
                single_ip_adapter_image, device, 1, output_hidden_state
            )
            single_image_embeds = torch.stack(
                [single_image_embeds] * num_images_per_prompt, dim=0
            )
            single_negative_image_embeds = torch.stack(
                [single_negative_image_embeds] * num_images_per_prompt, dim=0
            )

            if do_classifier_free_guidance:
                single_image_embeds = torch.cat(
                    [single_negative_image_embeds, single_image_embeds]
                )
                single_image_embeds = single_image_embeds.to(device)

            image_embeds.append(single_image_embeds)
    else:
        repeat_dims = [1]
        image_embeds = []
        for single_image_embeds in ip_adapter_image_embeds:
            if do_classifier_free_guidance:
                single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
                # Repeat along dim 0 only; every trailing dim keeps its size.
                single_image_embeds = single_image_embeds.repeat(
                    num_images_per_prompt,
                    *(repeat_dims * len(single_image_embeds.shape[1:])),
                )
                single_negative_image_embeds = single_negative_image_embeds.repeat(
                    num_images_per_prompt,
                    *(repeat_dims * len(single_negative_image_embeds.shape[1:])),
                )
                single_image_embeds = torch.cat(
                    [single_negative_image_embeds, single_image_embeds]
                )
            else:
                single_image_embeds = single_image_embeds.repeat(
                    num_images_per_prompt,
                    *(repeat_dims * len(single_image_embeds.shape[1:])),
                )
            image_embeds.append(single_image_embeds)

    return image_embeds
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
    """
    Run the optional safety checker over a batch of generated images.

    Args:
        image: Generated images, either a tensor or an array accepted by
            `self.image_processor.numpy_to_pil`.
        device: Device the feature-extractor output is moved to.
        dtype: Dtype the CLIP pixel values are cast to before the check.

    Returns:
        `tuple`: `(image, has_nsfw_concept)`. When no safety checker is configured the
        input is returned untouched and `has_nsfw_concept` is `None`.
    """
    # Nothing to screen with: hand the images back unchanged.
    if self.safety_checker is None:
        return image, None

    # The feature extractor is fed PIL images, so convert from whichever form we received.
    if torch.is_tensor(image):
        pil_images = self.image_processor.postprocess(image, output_type="pil")
    else:
        pil_images = self.image_processor.numpy_to_pil(image)

    clip_features = self.feature_extractor(pil_images, return_tensors="pt").to(device)
    image, has_nsfw_concept = self.safety_checker(
        images=image, clip_input=clip_features.pixel_values.to(dtype)
    )
    return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
def decode_latents(self, latents):
    """
    Decode latents into a channels-last float32 numpy image batch (deprecated).

    Emits a deprecation notice through `deprecate`, undoes the VAE scaling factor,
    decodes with `self.vae`, and maps the result from `[-1, 1]` to `[0, 1]`.

    Args:
        latents: Latents to decode.

    Returns:
        The decoded images as a numpy array, permuted with `(0, 2, 3, 1)`.
    """
    deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
    deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)

    unscaled = 1 / self.vae.config.scaling_factor * latents
    decoded = self.vae.decode(unscaled, return_dict=False)[0]
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
    return decoded.cpu().permute(0, 2, 3, 1).float().numpy()
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
    """
    Collect the optional keyword arguments that `self.scheduler.step` accepts.

    Not all schedulers share the same `step` signature, so `eta` and `generator` are only
    forwarded when the signature names them. eta (η) corresponds to η in the DDIM paper
    (https://arxiv.org/abs/2010.02502) and should be between [0, 1].

    Args:
        generator: Random generator(s) to forward when supported.
        eta: DDIM eta value to forward when supported.

    Returns:
        `dict`: A subset of `{"eta": eta, "generator": generator}`.
    """
    step_parameters = inspect.signature(self.scheduler.step).parameters
    optional_kwargs = (("eta", eta), ("generator", generator))
    return {name: value for name, value in optional_kwargs if name in step_parameters}
def check_inputs(
    self,
    prompt,
    image,
    mask,
    callback_steps,
    negative_prompt=None,
    prompt_embeds=None,
    negative_prompt_embeds=None,
    ip_adapter_image=None,
    ip_adapter_image_embeds=None,
    brushnet_conditioning_scale=1.0,
    control_guidance_start=0.0,
    control_guidance_end=1.0,
    callback_on_step_end_tensor_inputs=None,
):
    """
    Validate the user-facing arguments of the pipeline call.

    Checks run in this order: callback settings, prompt / prompt-embedding exclusivity,
    negative prompt exclusivity, embedding shape agreement, `image`/`mask` (through
    `self.check_image`), `brushnet_conditioning_scale`, the control-guidance window and
    finally the IP-Adapter inputs.

    Raises:
        ValueError: For any inconsistent or out-of-range argument.
        TypeError: If `brushnet_conditioning_scale` is not a `float` (or as raised by
            `self.check_image`).
    """
    # --- callbacks -------------------------------------------------------------------
    if callback_steps is not None:
        if not isinstance(callback_steps, int) or callback_steps <= 0:
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )

    if callback_on_step_end_tensor_inputs is not None:
        unknown_inputs = [
            k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs
        ]
        if unknown_inputs:
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {unknown_inputs}"
            )

    # --- prompt / prompt_embeds ------------------------------------------------------
    has_prompt = prompt is not None
    has_prompt_embeds = prompt_embeds is not None
    if has_prompt and has_prompt_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if not has_prompt and not has_prompt_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if has_prompt and not isinstance(prompt, str) and not isinstance(prompt, list):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # --- negative prompt / embeddings ------------------------------------------------
    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if (
        has_prompt_embeds
        and negative_prompt_embeds is not None
        and prompt_embeds.shape != negative_prompt_embeds.shape
    ):
        raise ValueError(
            "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
            f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
            f" {negative_prompt_embeds.shape}."
        )

    # --- `image` and `brushnet_conditioning_scale` -----------------------------------
    # A torch.compile-wrapped brushnet keeps the real module under `_orig_mod`.
    is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
        self.brushnet, torch._dynamo.eval_frame.OptimizedModule
    )
    is_single_brushnet = isinstance(self.brushnet, BrushNetModel) or (
        is_compiled and isinstance(self.brushnet._orig_mod, BrushNetModel)
    )
    if is_single_brushnet:
        self.check_image(image, mask, prompt, prompt_embeds)
        if not isinstance(brushnet_conditioning_scale, float):
            raise TypeError("For single brushnet: `brushnet_conditioning_scale` must be type `float`.")
    else:
        assert False

    # --- control guidance window -----------------------------------------------------
    starts = control_guidance_start
    if not isinstance(starts, (tuple, list)):
        starts = [starts]
    ends = control_guidance_end
    if not isinstance(ends, (tuple, list)):
        ends = [ends]

    if len(starts) != len(ends):
        raise ValueError(
            f"`control_guidance_start` has {len(starts)} elements, but `control_guidance_end` has {len(ends)} elements. Make sure to provide the same number of elements to each list."
        )

    for start, end in zip(starts, ends):
        if start >= end:
            raise ValueError(
                f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
            )
        if start < 0.0:
            raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
        if end > 1.0:
            raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")

    # --- IP-Adapter ------------------------------------------------------------------
    if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
        raise ValueError(
            "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
        )

    if ip_adapter_image_embeds is not None:
        if not isinstance(ip_adapter_image_embeds, list):
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
            )
        if ip_adapter_image_embeds[0].ndim not in [3, 4]:
            raise ValueError(
                f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
            )
def check_image(self, image, mask, prompt, prompt_embeds):
    """
    Validate the types of `image` and `mask` and the image/prompt batch sizes.

    Both `image` and `mask` must be a PIL image, a torch tensor, a numpy array, or a
    non-empty list whose first element is one of those. Only the first list element is
    inspected. An empty list is reported with the same `TypeError` as any other
    unsupported input instead of escaping as an `IndexError`.

    Args:
        image: Conditioning image(s).
        mask: Mask(s) accompanying `image`.
        prompt: Prompt string or list of prompts, or `None`.
        prompt_embeds: Pre-computed prompt embeddings, or `None`.

    Raises:
        TypeError: If `image` or `mask` is not one of the supported types.
        ValueError: If the image batch size is neither 1 nor the prompt batch size.
    """
    element_types = (PIL.Image.Image, torch.Tensor, np.ndarray)

    def _is_supported(value):
        # A bare image/tensor/array is fine as is.
        if isinstance(value, element_types):
            return True
        # For lists, guard against the empty case before peeking at the first element.
        return isinstance(value, list) and len(value) > 0 and isinstance(value[0], element_types)

    if not _is_supported(image):
        raise TypeError(
            f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
        )

    if not _is_supported(mask):
        raise TypeError(
            f"mask must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(mask)}"
        )

    # A single PIL image is a batch of one; for everything else `len` is taken as the batch size.
    image_batch_size = 1 if isinstance(image, PIL.Image.Image) else len(image)

    # `prompt_batch_size` stays unset when neither `prompt` nor `prompt_embeds` is given;
    # `check_inputs` rejects that combination before calling this method.
    if prompt is not None and isinstance(prompt, str):
        prompt_batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        prompt_batch_size = len(prompt)
    elif prompt_embeds is not None:
        prompt_batch_size = prompt_embeds.shape[0]

    if image_batch_size != 1 and image_batch_size != prompt_batch_size:
        raise ValueError(
            f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
        )
def prepare_image(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """
    Preprocess a conditioning image (or mask) and expand it to the effective batch size.

    Args:
        image: Input forwarded to `self.image_processor.preprocess`.
        width: Target width passed to the preprocessor.
        height: Target height passed to the preprocessor.
        batch_size: Repeat count used when the preprocessed batch holds a single image.
        num_images_per_prompt: Repeat count used when the batch holds several images.
        device: Device of the returned tensor.
        dtype: Dtype of the returned tensor.
        do_classifier_free_guidance: When true (and not `guess_mode`), the batch is
            duplicated along dim 0.
        guess_mode: Suppresses the classifier-free-guidance duplication.

    Returns:
        The prepared tensor on `device` with dtype `dtype`.
    """
    pixels = self.image_processor.preprocess(image, height=height, width=width)
    pixels = pixels.to(dtype=torch.float32)

    # A single image is broadcast to the whole batch; otherwise the image batch size is
    # the same as the prompt batch size and each image is repeated per prompt.
    copies = batch_size if pixels.shape[0] == 1 else num_images_per_prompt
    pixels = pixels.repeat_interleave(copies, dim=0)
    pixels = pixels.to(device=device, dtype=dtype)

    if do_classifier_free_guidance and not guess_mode:
        pixels = torch.cat([pixels, pixels])

    return pixels.to(device=device, dtype=dtype)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
2024-04-29 16:20:44 +02:00
def prepare_latents(
    self,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """
    Produce the initial latents for denoising together with the unscaled noise.

    Args:
        batch_size: Effective batch size of the latents.
        num_channels_latents: Number of latent channels.
        height: Image height in pixels; divided by `self.vae_scale_factor`.
        width: Image width in pixels; divided by `self.vae_scale_factor`.
        dtype: Dtype used when noise is sampled.
        device: Device of the noise.
        generator: A generator or a list of generators; a list must have `batch_size` entries.
        latents: Optional pre-made noise. When given it is only moved to `device`.

    Returns:
        `tuple`: `(latents, noise)` where `latents` is `noise` multiplied by
        `self.scheduler.init_noise_sigma`.

    Raises:
        ValueError: If `generator` is a list whose length differs from `batch_size`.
    """
    latent_height = height // self.vae_scale_factor
    latent_width = width // self.vae_scale_factor
    latent_shape = (batch_size, num_channels_latents, latent_height, latent_width)

    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    if latents is not None:
        noise = latents.to(device)
    else:
        noise = randn_tensor(latent_shape, generator=generator, device=device, dtype=dtype)

    # scale the initial noise by the standard deviation required by the scheduler
    scaled_latents = noise * self.scheduler.init_noise_sigma
    return scaled_latents, noise
# Adapted from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
    """
    Build sinusoidal embeddings of the guidance scale values in `w`.

    See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298

    Args:
        w (`torch.Tensor`):
            1-D tensor of guidance scale values to embed.
        embedding_dim (`int`, *optional*, defaults to 512):
            dimension of the embeddings to generate
        dtype:
            data type of the generated embeddings

    Returns:
        `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
    """
    assert len(w.shape) == 1
    w = w * 1000.0

    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    # Build the frequency table on the same device as `w` so the product below cannot
    # fail with a device mismatch when `w` does not live on the CPU.
    emb = torch.exp(torch.arange(half_dim, dtype=dtype, device=w.device) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1))
    assert emb.shape == (w.shape[0], embedding_dim)
    return emb
@property
def guidance_scale(self):
    """Guidance scale currently stored on the pipeline as `_guidance_scale`."""
    return self._guidance_scale
@property
def clip_skip(self):
    """`clip_skip` value currently stored on the pipeline as `_clip_skip`."""
    return self._clip_skip
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
    """
    Whether classifier-free guidance is applied.

    True only when the stored guidance scale is greater than 1 and the UNet config has no
    `time_cond_proj_dim`.
    """
    scale_enables_guidance = self._guidance_scale > 1
    return scale_enables_guidance and self.unet.config.time_cond_proj_dim is None
@property
def cross_attention_kwargs(self):
    """Cross-attention kwargs currently stored on the pipeline as `_cross_attention_kwargs`."""
    return self._cross_attention_kwargs
@property
def num_timesteps(self):
    """Number of timesteps currently stored on the pipeline as `_num_timesteps`."""
    return self._num_timesteps
@torch.no_grad ( )
@replace_example_docstring ( EXAMPLE_DOC_STRING )
def __call__ (
2024-04-29 16:20:44 +02:00
self ,
promptA : Union [ str , List [ str ] ] = None ,
promptB : Union [ str , List [ str ] ] = None ,
promptU : Union [ str , List [ str ] ] = None ,
tradoff : float = 1.0 ,
tradoff_nag : float = 1.0 ,
image : PipelineImageInput = None ,
mask : PipelineImageInput = None ,
height : Optional [ int ] = None ,
width : Optional [ int ] = None ,
num_inference_steps : int = 50 ,
timesteps : List [ int ] = None ,
guidance_scale : float = 7.5 ,
negative_promptA : Optional [ Union [ str , List [ str ] ] ] = None ,
negative_promptB : Optional [ Union [ str , List [ str ] ] ] = None ,
negative_promptU : Optional [ Union [ str , List [ str ] ] ] = None ,
num_images_per_prompt : Optional [ int ] = 1 ,
eta : float = 0.0 ,
generator : Optional [ Union [ torch . Generator , List [ torch . Generator ] ] ] = None ,
latents : Optional [ torch . FloatTensor ] = None ,
prompt_embeds : Optional [ torch . FloatTensor ] = None ,
negative_prompt_embeds : Optional [ torch . FloatTensor ] = None ,
ip_adapter_image : Optional [ PipelineImageInput ] = None ,
ip_adapter_image_embeds : Optional [ List [ torch . FloatTensor ] ] = None ,
output_type : Optional [ str ] = " pil " ,
return_dict : bool = True ,
cross_attention_kwargs : Optional [ Dict [ str , Any ] ] = None ,
brushnet_conditioning_scale : Union [ float , List [ float ] ] = 1.0 ,
guess_mode : bool = False ,
control_guidance_start : Union [ float , List [ float ] ] = 0.0 ,
control_guidance_end : Union [ float , List [ float ] ] = 1.0 ,
clip_skip : Optional [ int ] = None ,
callback_on_step_end : Optional [ Callable [ [ int , int , Dict ] , None ] ] = None ,
callback_on_step_end_tensor_inputs : List [ str ] = [ " latents " ] ,
* * kwargs ,
2024-04-24 14:22:29 +02:00
) :
r """
The call function to the pipeline for generation .
Args :
prompt ( ` str ` or ` List [ str ] ` , * optional * ) :
The prompt or prompts to guide image generation . If not defined , you need to pass ` prompt_embeds ` .
image ( ` torch . FloatTensor ` , ` PIL . Image . Image ` , ` np . ndarray ` , ` List [ torch . FloatTensor ] ` , ` List [ PIL . Image . Image ] ` , ` List [ np . ndarray ] ` , :
` List [ List [ torch . FloatTensor ] ] ` , ` List [ List [ np . ndarray ] ] ` or ` List [ List [ PIL . Image . Image ] ] ` ) :
The BrushNet input condition to provide guidance to the ` unet ` for generation . If the type is
specified as ` torch . FloatTensor ` , it is passed to BrushNet as is . ` PIL . Image . Image ` can also be
accepted as an image . The dimensions of the output image defaults to ` image ` ' s dimensions. If height
and / or width are passed , ` image ` is resized accordingly . If multiple BrushNets are specified in
` init ` , images must be passed as a list such that each element of the list can be correctly batched for
input to a single BrushNet . When ` prompt ` is a list , and if a list of images is passed for a single BrushNet ,
each will be paired with each prompt in the ` prompt ` list . This also applies to multiple BrushNets ,
where a list of image lists can be passed to batch for each prompt and each BrushNet .
mask ( ` torch . FloatTensor ` , ` PIL . Image . Image ` , ` np . ndarray ` , ` List [ torch . FloatTensor ] ` , ` List [ PIL . Image . Image ] ` , ` List [ np . ndarray ] ` , :
` List [ List [ torch . FloatTensor ] ] ` , ` List [ List [ np . ndarray ] ] ` or ` List [ List [ PIL . Image . Image ] ] ` ) :
The BrushNet input condition to provide guidance to the ` unet ` for generation . If the type is
specified as ` torch . FloatTensor ` , it is passed to BrushNet as is . ` PIL . Image . Image ` can also be
accepted as an image . The dimensions of the output image defaults to ` image ` ' s dimensions. If height
and / or width are passed , ` image ` is resized accordingly . If multiple BrushNets are specified in
` init ` , images must be passed as a list such that each element of the list can be correctly batched for
input to a single BrushNet . When ` prompt ` is a list , and if a list of images is passed for a single BrushNet ,
each will be paired with each prompt in the ` prompt ` list . This also applies to multiple BrushNets ,
where a list of image lists can be passed to batch for each prompt and each BrushNet .
height ( ` int ` , * optional * , defaults to ` self . unet . config . sample_size * self . vae_scale_factor ` ) :
The height in pixels of the generated image .
width ( ` int ` , * optional * , defaults to ` self . unet . config . sample_size * self . vae_scale_factor ` ) :
The width in pixels of the generated image .
num_inference_steps ( ` int ` , * optional * , defaults to 50 ) :
The number of denoising steps . More denoising steps usually lead to a higher quality image at the
expense of slower inference .
timesteps ( ` List [ int ] ` , * optional * ) :
Custom timesteps to use for the denoising process with schedulers which support a ` timesteps ` argument
in their ` set_timesteps ` method . If not defined , the default behavior when ` num_inference_steps ` is
passed will be used . Must be in descending order .
guidance_scale ( ` float ` , * optional * , defaults to 7.5 ) :
A higher guidance scale value encourages the model to generate images closely linked to the text
` prompt ` at the expense of lower image quality . Guidance scale is enabled when ` guidance_scale > 1 ` .
negative_prompt ( ` str ` or ` List [ str ] ` , * optional * ) :
The prompt or prompts to guide what to not include in image generation . If not defined , you need to
pass ` negative_prompt_embeds ` instead . Ignored when not using guidance ( ` guidance_scale < 1 ` ) .
num_images_per_prompt ( ` int ` , * optional * , defaults to 1 ) :
The number of images to generate per prompt .
eta ( ` float ` , * optional * , defaults to 0.0 ) :
Corresponds to parameter eta ( η ) from the [ DDIM ] ( https : / / arxiv . org / abs / 2010.02502 ) paper . Only applies
to the [ ` ~ schedulers . DDIMScheduler ` ] , and is ignored in other schedulers .
generator ( ` torch . Generator ` or ` List [ torch . Generator ] ` , * optional * ) :
A [ ` torch . Generator ` ] ( https : / / pytorch . org / docs / stable / generated / torch . Generator . html ) to make
generation deterministic .
latents ( ` torch . FloatTensor ` , * optional * ) :
Pre - generated noisy latents sampled from a Gaussian distribution , to be used as inputs for image
generation . Can be used to tweak the same generation with different prompts . If not provided , a latents
tensor is generated by sampling using the supplied random ` generator ` .
prompt_embeds ( ` torch . FloatTensor ` , * optional * ) :
Pre - generated text embeddings . Can be used to easily tweak text inputs ( prompt weighting ) . If not
provided , text embeddings are generated from the ` prompt ` input argument .
negative_prompt_embeds ( ` torch . FloatTensor ` , * optional * ) :
Pre - generated negative text embeddings . Can be used to easily tweak text inputs ( prompt weighting ) . If
not provided , ` negative_prompt_embeds ` are generated from the ` negative_prompt ` input argument .
ip_adapter_image : ( ` PipelineImageInput ` , * optional * ) : Optional image input to work with IP Adapters .
ip_adapter_image_embeds ( ` List [ torch . FloatTensor ] ` , * optional * ) :
Pre - generated image embeddings for IP - Adapter . It should be a list of length same as number of IP - adapters .
Each element should be a tensor of shape ` ( batch_size , num_images , emb_dim ) ` . It should contain the negative image embedding
if ` do_classifier_free_guidance ` is set to ` True ` .
If not provided , embeddings are computed from the ` ip_adapter_image ` input argument .
output_type ( ` str ` , * optional * , defaults to ` " pil " ` ) :
The output format of the generated image . Choose between ` PIL . Image ` or ` np . array ` .
return_dict ( ` bool ` , * optional * , defaults to ` True ` ) :
Whether or not to return a [ ` ~ pipelines . stable_diffusion . StableDiffusionPipelineOutput ` ] instead of a
plain tuple .
callback ( ` Callable ` , * optional * ) :
A function that calls every ` callback_steps ` steps during inference . The function is called with the
following arguments : ` callback ( step : int , timestep : int , latents : torch . FloatTensor ) ` .
callback_steps ( ` int ` , * optional * , defaults to 1 ) :
The frequency at which the ` callback ` function is called . If not specified , the callback is called at
every step .
cross_attention_kwargs ( ` dict ` , * optional * ) :
A kwargs dictionary that if specified is passed along to the [ ` AttentionProcessor ` ] as defined in
[ ` self . processor ` ] ( https : / / github . com / huggingface / diffusers / blob / main / src / diffusers / models / attention_processor . py ) .
brushnet_conditioning_scale ( ` float ` or ` List [ float ] ` , * optional * , defaults to 1.0 ) :
The outputs of the BrushNet are multiplied by ` brushnet_conditioning_scale ` before they are added
to the residual in the original ` unet ` . If multiple BrushNets are specified in ` init ` , you can set
the corresponding scale as a list .
guess_mode ( ` bool ` , * optional * , defaults to ` False ` ) :
The BrushNet encoder tries to recognize the content of the input image even if you remove all
prompts . A ` guidance_scale ` value between 3.0 and 5.0 is recommended .
control_guidance_start ( ` float ` or ` List [ float ] ` , * optional * , defaults to 0.0 ) :
The percentage of total steps at which the BrushNet starts applying .
control_guidance_end ( ` float ` or ` List [ float ] ` , * optional * , defaults to 1.0 ) :
The percentage of total steps at which the BrushNet stops applying .
clip_skip ( ` int ` , * optional * ) :
Number of layers to be skipped from CLIP while computing the prompt embeddings . A value of 1 means that
the output of the pre - final layer will be used for computing the prompt embeddings .
callback_on_step_end ( ` Callable ` , * optional * ) :
A function that calls at the end of each denoising steps during the inference . The function is called
with the following arguments : ` callback_on_step_end ( self : DiffusionPipeline , step : int , timestep : int ,
callback_kwargs : Dict ) ` . ` callback_kwargs ` will include a list of all tensors as specified by
` callback_on_step_end_tensor_inputs ` .
callback_on_step_end_tensor_inputs ( ` List ` , * optional * ) :
The list of tensor inputs for the ` callback_on_step_end ` function . The tensors specified in the list
will be passed as ` callback_kwargs ` argument . You will only be able to include variables listed in the
` . _callback_tensor_inputs ` attribute of your pipeine class .
Examples :
Returns :
[ ` ~ pipelines . stable_diffusion . StableDiffusionPipelineOutput ` ] or ` tuple ` :
If ` return_dict ` is ` True ` , [ ` ~ pipelines . stable_diffusion . StableDiffusionPipelineOutput ` ] is returned ,
otherwise a ` tuple ` is returned where the first element is a list with the generated images and the
second element is a list of ` bool ` s indicating whether the corresponding generated image contains
" not-safe-for-work " ( nsfw ) content .
"""
callback = kwargs . pop ( " callback " , None )
callback_steps = kwargs . pop ( " callback_steps " , None )
if callback is not None :
deprecate (
" callback " ,
" 1.0.0 " ,
" Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end` " ,
)
if callback_steps is not None :
deprecate (
" callback_steps " ,
" 1.0.0 " ,
" Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end` " ,
)
2024-04-29 16:20:44 +02:00
brushnet = (
self . brushnet . _orig_mod
if is_compiled_module ( self . brushnet )
else self . brushnet
)
2024-04-24 14:22:29 +02:00
# align format for control guidance
2024-04-29 16:20:44 +02:00
if not isinstance ( control_guidance_start , list ) and isinstance (
control_guidance_end , list
) :
control_guidance_start = len ( control_guidance_end ) * [
control_guidance_start
]
elif not isinstance ( control_guidance_end , list ) and isinstance (
control_guidance_start , list
) :
2024-04-24 14:22:29 +02:00
control_guidance_end = len ( control_guidance_start ) * [ control_guidance_end ]
2024-04-29 16:20:44 +02:00
elif not isinstance ( control_guidance_start , list ) and not isinstance (
control_guidance_end , list
) :
2024-04-24 14:22:29 +02:00
control_guidance_start , control_guidance_end = (
[ control_guidance_start ] ,
[ control_guidance_end ] ,
)
# 1. Check inputs. Raise error if not correct
prompt = promptA
negative_prompt = negative_promptA
self . check_inputs (
prompt ,
image ,
mask ,
callback_steps ,
negative_prompt ,
prompt_embeds ,
negative_prompt_embeds ,
ip_adapter_image ,
ip_adapter_image_embeds ,
brushnet_conditioning_scale ,
control_guidance_start ,
control_guidance_end ,
callback_on_step_end_tensor_inputs ,
)
self . _guidance_scale = guidance_scale
self . _clip_skip = clip_skip
self . _cross_attention_kwargs = cross_attention_kwargs
# 2. Define call parameters
if prompt is not None and isinstance ( prompt , str ) :
batch_size = 1
elif prompt is not None and isinstance ( prompt , list ) :
batch_size = len ( prompt )
else :
batch_size = prompt_embeds . shape [ 0 ]
device = self . _execution_device
global_pool_conditions = (
brushnet . config . global_pool_conditions
if isinstance ( brushnet , BrushNetModel )
else brushnet . nets [ 0 ] . config . global_pool_conditions
)
guess_mode = guess_mode or global_pool_conditions
# 3. Encode input prompt
text_encoder_lora_scale = (
2024-04-29 16:20:44 +02:00
self . cross_attention_kwargs . get ( " scale " , None )
if self . cross_attention_kwargs is not None
else None
2024-04-24 14:22:29 +02:00
)
prompt_embeds = self . _encode_prompt (
promptA ,
promptB ,
tradoff ,
device ,
num_images_per_prompt ,
self . do_classifier_free_guidance ,
negative_promptA ,
negative_promptB ,
tradoff_nag ,
prompt_embeds = prompt_embeds ,
negative_prompt_embeds = negative_prompt_embeds ,
lora_scale = text_encoder_lora_scale ,
)
prompt_embedsU = None
negative_prompt_embedsU = None
prompt_embedsU = self . encode_prompt (
promptU ,
device ,
num_images_per_prompt ,
self . do_classifier_free_guidance ,
negative_promptU ,
prompt_embeds = prompt_embedsU ,
negative_prompt_embeds = negative_prompt_embedsU ,
lora_scale = text_encoder_lora_scale ,
)
if ip_adapter_image is not None or ip_adapter_image_embeds is not None :
image_embeds = self . prepare_ip_adapter_image_embeds (
ip_adapter_image ,
ip_adapter_image_embeds ,
device ,
batch_size * num_images_per_prompt ,
self . do_classifier_free_guidance ,
)
# 4. Prepare image
if isinstance ( brushnet , BrushNetModel ) :
image = self . prepare_image (
image = image ,
width = width ,
height = height ,
batch_size = batch_size * num_images_per_prompt ,
num_images_per_prompt = num_images_per_prompt ,
device = device ,
dtype = brushnet . dtype ,
do_classifier_free_guidance = self . do_classifier_free_guidance ,
guess_mode = guess_mode ,
)
original_mask = self . prepare_image (
image = mask ,
width = width ,
height = height ,
batch_size = batch_size * num_images_per_prompt ,
num_images_per_prompt = num_images_per_prompt ,
device = device ,
dtype = brushnet . dtype ,
do_classifier_free_guidance = self . do_classifier_free_guidance ,
guess_mode = guess_mode ,
)
original_mask = ( original_mask . sum ( 1 ) [ : , None , : , : ] < 0 ) . to ( image . dtype )
height , width = image . shape [ - 2 : ]
else :
assert False
# 5. Prepare timesteps
2024-04-29 16:20:44 +02:00
timesteps , num_inference_steps = retrieve_timesteps (
self . scheduler , num_inference_steps , device , timesteps
)
2024-04-24 14:22:29 +02:00
self . _num_timesteps = len ( timesteps )
# 6. Prepare latent variables
num_channels_latents = self . unet . config . in_channels
latents , noise = self . prepare_latents (
batch_size * num_images_per_prompt ,
num_channels_latents ,
height ,
width ,
prompt_embeds . dtype ,
device ,
generator ,
latents ,
)
# 6.1 prepare condition latents
# mask_i = transforms.ToPILImage()(image[0:1,:,:,:].squeeze(0))
# mask_i.save('_mask.png')
# print(brushnet.dtype)
2024-04-29 16:20:44 +02:00
conditioning_latents = (
self . vae . encode (
image . to ( device = device , dtype = brushnet . dtype )
) . latent_dist . sample ( )
* self . vae . config . scaling_factor
)
2024-04-24 14:22:29 +02:00
mask = torch . nn . functional . interpolate (
original_mask ,
2024-04-29 16:20:44 +02:00
size = ( conditioning_latents . shape [ - 2 ] , conditioning_latents . shape [ - 1 ] ) ,
2024-04-24 14:22:29 +02:00
)
conditioning_latents = torch . concat ( [ conditioning_latents , mask ] , 1 )
# image = self.vae.decode(conditioning_latents[:1,:4,:,:] / self.vae.config.scaling_factor, return_dict=False, generator=generator)[0]
# from torchvision import transforms
# mask_i = transforms.ToPILImage()(image[0:1,:,:,:].squeeze(0)/2+0.5)
# mask_i.save(str(timesteps[0]) +'_C.png')
# 6.5 Optionally get Guidance Scale Embedding
timestep_cond = None
if self . unet . config . time_cond_proj_dim is not None :
2024-04-29 16:20:44 +02:00
guidance_scale_tensor = torch . tensor ( self . guidance_scale - 1 ) . repeat (
batch_size * num_images_per_prompt
)
2024-04-24 14:22:29 +02:00
timestep_cond = self . get_guidance_scale_embedding (
guidance_scale_tensor , embedding_dim = self . unet . config . time_cond_proj_dim
) . to ( device = device , dtype = latents . dtype )
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self . prepare_extra_step_kwargs ( generator , eta )
# 7.1 Add image embeds for IP-Adapter
added_cond_kwargs = (
{ " image_embeds " : image_embeds }
if ip_adapter_image is not None or ip_adapter_image_embeds is not None
else None
)
# 7.2 Create tensor stating which brushnets to keep
brushnet_keep = [ ]
for i in range ( len ( timesteps ) ) :
keeps = [
1.0 - float ( i / len ( timesteps ) < s or ( i + 1 ) / len ( timesteps ) > e )
for s , e in zip ( control_guidance_start , control_guidance_end )
]
2024-04-29 16:20:44 +02:00
brushnet_keep . append (
keeps [ 0 ] if isinstance ( brushnet , BrushNetModel ) else keeps
)
2024-04-24 14:22:29 +02:00
# 8. Denoising loop
num_warmup_steps = len ( timesteps ) - num_inference_steps * self . scheduler . order
is_unet_compiled = is_compiled_module ( self . unet )
is_brushnet_compiled = is_compiled_module ( self . brushnet )
is_torch_higher_equal_2_1 = is_torch_version ( " >= " , " 2.1 " )
with self . progress_bar ( total = num_inference_steps ) as progress_bar :
for i , t in enumerate ( timesteps ) :
# Relevant thread:
# https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
2024-04-29 16:20:44 +02:00
if (
is_unet_compiled and is_brushnet_compiled
) and is_torch_higher_equal_2_1 :
2024-04-24 14:22:29 +02:00
torch . _inductor . cudagraph_mark_step_begin ( )
# expand the latents if we are doing classifier free guidance
2024-04-29 16:20:44 +02:00
latent_model_input = (
torch . cat ( [ latents ] * 2 )
if self . do_classifier_free_guidance
else latents
)
latent_model_input = self . scheduler . scale_model_input (
latent_model_input , t
)
2024-04-24 14:22:29 +02:00
# brushnet(s) inference
if guess_mode and self . do_classifier_free_guidance :
# Infer BrushNet only for the conditional batch.
control_model_input = latents
2024-04-29 16:20:44 +02:00
control_model_input = self . scheduler . scale_model_input (
control_model_input , t
)
2024-04-24 14:22:29 +02:00
brushnet_prompt_embeds = prompt_embeds . chunk ( 2 ) [ 1 ]
else :
control_model_input = latent_model_input
brushnet_prompt_embeds = prompt_embeds
if isinstance ( brushnet_keep [ i ] , list ) :
2024-04-29 16:20:44 +02:00
cond_scale = [
c * s
for c , s in zip ( brushnet_conditioning_scale , brushnet_keep [ i ] )
]
2024-04-24 14:22:29 +02:00
else :
brushnet_cond_scale = brushnet_conditioning_scale
if isinstance ( brushnet_cond_scale , list ) :
brushnet_cond_scale = brushnet_cond_scale [ 0 ]
cond_scale = brushnet_cond_scale * brushnet_keep [ i ]
2024-04-29 16:20:44 +02:00
down_block_res_samples , mid_block_res_sample , up_block_res_samples = (
self . brushnet (
control_model_input ,
t ,
encoder_hidden_states = brushnet_prompt_embeds ,
brushnet_cond = conditioning_latents ,
conditioning_scale = cond_scale ,
guess_mode = guess_mode ,
return_dict = False ,
)
2024-04-24 14:22:29 +02:00
)
if guess_mode and self . do_classifier_free_guidance :
# Inferred BrushNet only for the conditional batch.
# To apply the output of BrushNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
2024-04-29 16:20:44 +02:00
down_block_res_samples = [
torch . cat ( [ torch . zeros_like ( d ) , d ] )
for d in down_block_res_samples
]
mid_block_res_sample = torch . cat (
[ torch . zeros_like ( mid_block_res_sample ) , mid_block_res_sample ]
)
up_block_res_samples = [
torch . cat ( [ torch . zeros_like ( d ) , d ] )
for d in up_block_res_samples
]
2024-04-24 14:22:29 +02:00
# predict the noise residual
noise_pred = self . unet (
latent_model_input ,
t ,
encoder_hidden_states = prompt_embedsU ,
timestep_cond = timestep_cond ,
cross_attention_kwargs = self . cross_attention_kwargs ,
down_block_add_samples = down_block_res_samples ,
mid_block_add_sample = mid_block_res_sample ,
up_block_add_samples = up_block_res_samples ,
added_cond_kwargs = added_cond_kwargs ,
return_dict = False ,
) [ 0 ]
# perform guidance
if self . do_classifier_free_guidance :
noise_pred_uncond , noise_pred_text = noise_pred . chunk ( 2 )
2024-04-29 16:20:44 +02:00
noise_pred = noise_pred_uncond + self . guidance_scale * (
noise_pred_text - noise_pred_uncond
)
2024-04-24 14:22:29 +02:00
# compute the previous noisy sample x_t -> x_t-1
2024-04-29 16:20:44 +02:00
latents = self . scheduler . step (
noise_pred , t , latents , * * extra_step_kwargs , return_dict = False
) [ 0 ]
2024-04-24 14:22:29 +02:00
if callback_on_step_end is not None :
callback_kwargs = { }
for k in callback_on_step_end_tensor_inputs :
callback_kwargs [ k ] = locals ( ) [ k ]
callback_outputs = callback_on_step_end ( self , i , t , callback_kwargs )
latents = callback_outputs . pop ( " latents " , latents )
prompt_embeds = callback_outputs . pop ( " prompt_embeds " , prompt_embeds )
2024-04-29 16:20:44 +02:00
negative_prompt_embeds = callback_outputs . pop (
" negative_prompt_embeds " , negative_prompt_embeds
)
2024-04-24 14:22:29 +02:00
# update the progress bar and call the callback, if provided
2024-04-29 16:20:44 +02:00
if i == len ( timesteps ) - 1 or (
( i + 1 ) > num_warmup_steps and ( i + 1 ) % self . scheduler . order == 0
) :
2024-04-24 14:22:29 +02:00
progress_bar . update ( )
if callback is not None and i % callback_steps == 0 :
step_idx = i / / getattr ( self . scheduler , " order " , 1 )
callback ( step_idx , t , latents )
# If we do sequential model offloading, let's offload unet and brushnet
# manually for max memory savings
if hasattr ( self , " final_offload_hook " ) and self . final_offload_hook is not None :
self . unet . to ( " cpu " )
self . brushnet . to ( " cpu " )
torch . cuda . empty_cache ( )
if not output_type == " latent " :
2024-04-29 16:20:44 +02:00
image = self . vae . decode (
latents / self . vae . config . scaling_factor ,
return_dict = False ,
generator = generator ,
) [ 0 ]
image , has_nsfw_concept = self . run_safety_checker (
image , device , prompt_embeds . dtype
)
2024-04-24 14:22:29 +02:00
else :
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None :
do_denormalize = [ True ] * image . shape [ 0 ]
else :
do_denormalize = [ not has_nsfw for has_nsfw in has_nsfw_concept ]
2024-04-29 16:20:44 +02:00
image = self . image_processor . postprocess (
image , output_type = output_type , do_denormalize = do_denormalize
)
2024-04-24 14:22:29 +02:00
# Offload all models
self . maybe_free_model_hooks ( )
if not return_dict :
return ( image , has_nsfw_concept )
2024-04-29 16:20:44 +02:00
return StableDiffusionPipelineOutput (
images = image , nsfw_content_detected = has_nsfw_concept
)