2023-12-01 03:15:35 +01:00
import os
2023-11-14 07:19:56 +01:00
import PIL . Image
import cv2
import torch
2023-12-01 03:15:35 +01:00
from diffusers import AutoencoderKL
2023-11-14 07:19:56 +01:00
from loguru import logger
2024-01-05 08:19:23 +01:00
from iopaint . schema import InpaintRequest , ModelType
2023-11-14 07:19:56 +01:00
2024-01-05 09:40:06 +01:00
from . base import DiffusionInpaintModel
from . utils import handle_from_pretrained_exceptions
2023-11-14 07:19:56 +01:00
class SDXL(DiffusionInpaintModel):
    """Stable Diffusion XL inpainting wrapper.

    Loads either the dedicated SDXL 1.0 inpainting pipeline (9-channel UNet
    input) or a plain SDXL checkpoint (4-channel), then runs masked
    generation in ``forward``.
    """

    name = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
    # Inputs are padded so each side is a multiple of this before inference.
    pad_mod = 8
    # Minimum working resolution; smaller inputs are scaled up by the base class.
    min_size = 512
    # LoRA applied when the LCM scheduler is selected.
    lcm_lora_id = "latent-consistency/lcm-lora-sdxl"
    model_id_or_path = "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"

    def init_model(self, device: torch.device, **kwargs):
        """Build the diffusers pipeline and place it on ``device``.

        Recognized kwargs: ``no_half``, ``cpu_offload``,
        ``sd_cpu_textencoder``, ``callback``.
        """
        from diffusers.pipelines import StableDiffusionXLInpaintPipeline

        fp16 = not kwargs.get("no_half", False)
        use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
        torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32

        # A plain SDXL checkpoint has a 4-channel UNet input; the dedicated
        # inpainting variant takes image + mask + masked-image latents (9).
        if self.model_info.model_type == ModelType.DIFFUSERS_SDXL:
            num_in_channels = 4
        else:
            num_in_channels = 9

        if os.path.isfile(self.model_id_or_path):
            self.model = StableDiffusionXLInpaintPipeline.from_single_file(
                self.model_id_or_path,
                # FIX: the loader's documented parameter is `torch_dtype`;
                # the previous `dtype=` keyword was silently ignored, so
                # single-file checkpoints always loaded in fp32.
                torch_dtype=torch_dtype,
                num_in_channels=num_in_channels,
            )
        else:
            # The stock SDXL VAE overflows in fp16; use the fixed variant.
            vae = AutoencoderKL.from_pretrained(
                "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
            )
            self.model = handle_from_pretrained_exceptions(
                StableDiffusionXLInpaintPipeline.from_pretrained,
                pretrained_model_name_or_path=self.model_id_or_path,
                torch_dtype=torch_dtype,
                vae=vae,
                variant="fp16",
            )

        if torch.backends.mps.is_available():
            # MPS: Recommended RAM < 64 GB https://huggingface.co/docs/diffusers/optimization/mps
            # CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers. https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
            self.model.enable_attention_slicing()

        if kwargs.get("cpu_offload", False) and use_gpu:
            logger.info("Enable sequential cpu offload")
            self.model.enable_sequential_cpu_offload(gpu_id=0)
        else:
            self.model = self.model.to(device)
            # FIX: use .get() like every other option lookup in this method;
            # a plain subscript raised KeyError when the option was absent.
            if kwargs.get("sd_cpu_textencoder", False):
                logger.warning("Stable Diffusion XL not support run TextEncoder on CPU")

        # Per-step progress callback forwarded to the pipeline in forward().
        self.callback = kwargs.pop("callback", None)

    def forward(self, image, mask, config: InpaintRequest):
        """Input image and output image have same size
        image: [H, W, C] RGB
        mask: [H, W, 1] 255 means area to repaint
        return: BGR IMAGE
        """
        self.set_scheduler(config)

        img_h, img_w = image.shape[:2]
        output = self.model(
            image=PIL.Image.fromarray(image),
            prompt=config.prompt,
            negative_prompt=config.negative_prompt,
            mask_image=PIL.Image.fromarray(mask[:, :, -1], mode="L"),
            num_inference_steps=config.sd_steps,
            # strength == 1.0 would ignore the init image entirely; clamp
            # just below so the pipeline still conditions on it.
            strength=0.999 if config.sd_strength == 1.0 else config.sd_strength,
            guidance_scale=config.sd_guidance_scale,
            output_type="np",
            callback_on_step_end=self.callback,
            height=img_h,
            width=img_w,
            generator=torch.manual_seed(config.sd_seed),
        ).images[0]
        # Pipeline returns float RGB in [0, 1]; convert to uint8 BGR.
        output = (output * 255).round().astype("uint8")
        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
        return output