From 38b6edacf02ffb5b869a1fb34d96cb82d892e0f1 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 10 Jan 2024 13:34:11 +0800 Subject: [PATCH] sdxl support cpu_text_encoder --- iopaint/model/helper/cpu_text_encoder.py | 15 +++++++---- iopaint/model/sdxl.py | 13 +++++++--- iopaint/tests/test_sd_model.py | 27 ++++++++++++++++++++ iopaint/tests/test_sdxl.py | 32 ++++++++++++++++++++++++ iopaint/tests/utils.py | 2 +- 5 files changed, 80 insertions(+), 9 deletions(-) diff --git a/iopaint/model/helper/cpu_text_encoder.py b/iopaint/model/helper/cpu_text_encoder.py index ed01630..889bc97 100644 --- a/iopaint/model/helper/cpu_text_encoder.py +++ b/iopaint/model/helper/cpu_text_encoder.py @@ -8,6 +8,7 @@ class CPUTextEncoderWrapper(PreTrainedModel): def __init__(self, text_encoder, torch_dtype): super().__init__(text_encoder.config) self.config = text_encoder.config + # cpu not support float16 self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True) self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True) self.torch_dtype = torch_dtype @@ -16,11 +17,15 @@ class CPUTextEncoderWrapper(PreTrainedModel): def __call__(self, x, **kwargs): input_device = x.device - return [ - self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0] - .to(input_device) - .to(self.torch_dtype) - ] + original_output = self.text_encoder(x.to(self.text_encoder.device), **kwargs) + for k, v in original_output.items(): + if isinstance(v, tuple): + original_output[k] = [ + v[i].to(input_device).to(self.torch_dtype) for i in range(len(v)) + ] + else: + original_output[k] = v.to(input_device).to(self.torch_dtype) + return original_output @property def dtype(self): diff --git a/iopaint/model/sdxl.py b/iopaint/model/sdxl.py index f8795d7..277dec0 100644 --- a/iopaint/model/sdxl.py +++ b/iopaint/model/sdxl.py @@ -9,6 +9,7 @@ from loguru import logger from iopaint.schema import InpaintRequest, ModelType from .base import DiffusionInpaintModel +from .helper.cpu_text_encoder import CPUTextEncoderWrapper from .utils import handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem @@ -37,11 +38,11 @@ class SDXL(DiffusionInpaintModel): ) else: model_kwargs = {**kwargs.get("pipe_components", {})} - if 'vae' not in model_kwargs: + if "vae" not in model_kwargs: vae = AutoencoderKL.from_pretrained( "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype ) - model_kwargs['vae'] = vae + model_kwargs["vae"] = vae self.model = handle_from_pretrained_exceptions( StableDiffusionXLInpaintPipeline.from_pretrained, pretrained_model_name_or_path=self.model_id_or_path, @@ -58,7 +59,13 @@ class SDXL(DiffusionInpaintModel): else: self.model = self.model.to(device) if kwargs["sd_cpu_textencoder"]: - logger.warning("Stable Diffusion XL not support run TextEncoder on CPU") + logger.info("Run Stable Diffusion TextEncoder on CPU") + self.model.text_encoder = CPUTextEncoderWrapper( + self.model.text_encoder, torch_dtype + ) + self.model.text_encoder_2 = CPUTextEncoderWrapper( + self.model.text_encoder_2, torch_dtype + ) self.callback = kwargs.pop("callback", None) diff --git a/iopaint/tests/test_sd_model.py b/iopaint/tests/test_sd_model.py index 4aeaaa2..c6e4d11 100644 --- a/iopaint/tests/test_sd_model.py +++ b/iopaint/tests/test_sd_model.py @@ -147,6 +147,33 @@ def test_runway_sd_sd_strength(device, strategy, sampler): ) +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +def test_runway_sd_cpu_textencoder(device, strategy, sampler): + sd_steps = check_device(device) + model = ModelManager( + name="runwayml/stable-diffusion-inpainting", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=True, + ) + cfg = get_config( + strategy=strategy, + prompt="a fox sitting on a bench", + sd_steps=sd_steps, + sd_sampler=sampler, + ) + + assert_equal( + model, + cfg, + f"runway_sd_device_{device}_cpu_textencoder.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + @pytest.mark.parametrize("device", ["cuda", "mps", "cpu"]) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("sampler", [SDSampler.ddim]) diff --git a/iopaint/tests/test_sdxl.py b/iopaint/tests/test_sdxl.py index d99fc44..e236948 100644 --- a/iopaint/tests/test_sdxl.py +++ b/iopaint/tests/test_sdxl.py @@ -44,6 +44,38 @@ def test_sdxl(device, strategy, sampler): ) +@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +def test_sdxl_cpu_text_encoder(device, strategy, sampler): + sd_steps = check_device(device) + + model = ModelManager( + name="diffusers/stable-diffusion-xl-1.0-inpainting-0.1", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=True, + ) + cfg = get_config( + strategy=strategy, + prompt="face of a fox, sitting on a bench", + sd_steps=sd_steps, + sd_strength=1.0, + sd_guidance_scale=7.0, + ) + cfg.sd_sampler = sampler + + assert_equal( + model, + cfg, + f"sdxl_device_{device}.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + fx=2, + fy=2, + ) + + @pytest.mark.parametrize("device", ["cuda", "mps"]) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("sampler", [SDSampler.ddim]) diff --git a/iopaint/tests/utils.py b/iopaint/tests/utils.py index dc3afda..08f4aeb 100644 --- a/iopaint/tests/utils.py +++ b/iopaint/tests/utils.py @@ -17,7 +17,7 @@ def check_device(device: str) -> int: pytest.skip("CUDA is not available, skip test on cuda") if device == "mps" and not torch.backends.mps.is_available(): pytest.skip("mps is not available, skip test on mps") - steps = 1 if device == "cpu" else 20 + steps = 2 if device == "cpu" else 20 return steps