2022-09-15 16:21:27 +02:00
import PIL . Image
import cv2
import torch
from loguru import logger
2024-01-05 09:40:06 +01:00
from . base import DiffusionInpaintModel
from . helper . cpu_text_encoder import CPUTextEncoderWrapper
from . utils import handle_from_pretrained_exceptions
2024-01-05 08:19:23 +01:00
from iopaint . schema import InpaintRequest , ModelType
2022-09-15 16:21:27 +02:00
2023-01-27 13:59:22 +01:00
class SD(DiffusionInpaintModel):
    """Stable Diffusion inpainting model backed by diffusers' ``StableDiffusionInpaintPipeline``.

    Concrete variants subclass this and set ``name`` / ``model_id_or_path`` to
    choose the weights that ``init_model`` loads.
    """

    # The three attributes below are not read in this class; they are
    # presumably consumed by the DiffusionInpaintModel base class
    # (padding multiple, minimum side length, LCM-LoRA weights id) -- TODO confirm.
    pad_mod = 8
    min_size = 512
    lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"

    def init_model(self, device: torch.device, **kwargs):
        """Load the inpainting pipeline and place it on ``device``.

        Options read from ``kwargs`` (every flag defaults to ``False`` when absent):
            no_half: keep float32 weights even when running on CUDA.
            disable_nsfw: load the pipeline without its safety checker.
            cpu_offload: enable sequential CPU offload (only applied on CUDA);
                also loads the pipeline without the safety checker.
            sd_cpu_textencoder: wrap the text encoder in ``CPUTextEncoderWrapper``;
                only honoured when sequential offload is not active.
            callback: stored on ``self.callback`` and later handed to the
                pipeline as ``callback_on_step_end`` (defaults to ``None``).
        """
        from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline

        fp16 = not kwargs.get("no_half", False)
        cpu_offload = kwargs.get("cpu_offload", False)

        model_kwargs = {}
        if kwargs.get("disable_nsfw", False) or cpu_offload:
            logger.info("Disable Stable Diffusion Model NSFW checker")
            model_kwargs.update(
                safety_checker=None,
                feature_extractor=None,
                requires_safety_checker=False,
            )

        use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
        # Half precision is only requested when actually running on CUDA.
        torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32

        if self.model_info.is_single_file_diffusers:
            # Single-file checkpoints need the UNet input channel count spelled out:
            # 4 for ModelType.DIFFUSERS_SD, 9 for every other model type.
            if self.model_info.model_type == ModelType.DIFFUSERS_SD:
                model_kwargs["num_in_channels"] = 4
            else:
                model_kwargs["num_in_channels"] = 9

            # diffusers loaders select the precision through ``torch_dtype``;
            # ``dtype`` is not a documented loader argument.
            self.model = StableDiffusionInpaintPipeline.from_single_file(
                self.model_id_or_path, torch_dtype=torch_dtype, **model_kwargs
            )
        else:
            self.model = handle_from_pretrained_exceptions(
                StableDiffusionInpaintPipeline.from_pretrained,
                pretrained_model_name_or_path=self.model_id_or_path,
                variant="fp16",
                torch_dtype=torch_dtype,
                **model_kwargs,
            )

        if torch.backends.mps.is_available():
            # MPS: Recommended RAM < 64 GB https://huggingface.co/docs/diffusers/optimization/mps
            # CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers. https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
            self.model.enable_attention_slicing()

        if cpu_offload and use_gpu:
            logger.info("Enable sequential cpu offload")
            self.model.enable_sequential_cpu_offload(gpu_id=0)
        else:
            self.model = self.model.to(device)
            if kwargs.get("sd_cpu_textencoder", False):
                logger.info("Run Stable Diffusion TextEncoder on CPU")
                self.model.text_encoder = CPUTextEncoderWrapper(
                    self.model.text_encoder, torch_dtype
                )

        self.callback = kwargs.pop("callback", None)

    def forward(self, image, mask, config: InpaintRequest):
        """Inpaint ``image`` inside ``mask``; input and output have the same size.

        image: [H, W, C] RGB
        mask: [H, W, 1] 255 means area to repaint
        config: supplies the prompt, negative prompt, step count, strength,
            guidance scale and seed passed to the pipeline
        return: BGR IMAGE
        """
        self.set_scheduler(config)

        img_h, img_w = image.shape[:2]

        output = self.model(
            image=PIL.Image.fromarray(image),
            prompt=config.prompt,
            negative_prompt=config.negative_prompt,
            # Only the last mask channel is used, as a single-channel "L" image.
            mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
            num_inference_steps=config.sd_steps,
            strength=config.sd_strength,
            guidance_scale=config.sd_guidance_scale,
            output_type="np",
            callback_on_step_end=self.callback,
            height=img_h,
            width=img_w,
            # Seeds torch's global RNG and passes the returned generator along.
            generator=torch.manual_seed(config.sd_seed),
        ).images[0]

        # The result is scaled by 255 and rounded into uint8, then the
        # channel order is flipped from RGB to BGR.
        output = (output * 255).round().astype("uint8")
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return output
class SD15(SD):
    """SD variant that loads the ``runwayml/stable-diffusion-inpainting`` weights."""

    # The same identifier serves as the model's name and as the id/path
    # handed to the pipeline loader.
    model_id_or_path = "runwayml/stable-diffusion-inpainting"
    name = model_id_or_path
2022-12-04 06:41:48 +01:00
2023-03-01 14:44:02 +01:00
class Anything4(SD):
    """SD variant that loads the ``Sanster/anything-4.0-inpainting`` weights."""

    # The same identifier serves as the model's name and as the id/path
    # handed to the pipeline loader.
    model_id_or_path = "Sanster/anything-4.0-inpainting"
    name = model_id_or_path
class RealisticVision14(SD):
    """SD variant that loads the ``Sanster/Realistic_Vision_V1.4-inpainting`` weights."""

    # The same identifier serves as the model's name and as the id/path
    # handed to the pipeline loader.
    model_id_or_path = "Sanster/Realistic_Vision_V1.4-inpainting"
    name = model_id_or_path
2022-12-04 06:41:48 +01:00
class SD2(SD):
    """SD variant that loads the ``stabilityai/stable-diffusion-2-inpainting`` weights."""

    # The same identifier serves as the model's name and as the id/path
    # handed to the pipeline loader.
    model_id_or_path = "stabilityai/stable-diffusion-2-inpainting"
    name = model_id_or_path