From a49c3f86d31873730519f5cfc6ff0393c29d7413 Mon Sep 17 00:00:00 2001
From: Qing
Date: Mon, 8 Jan 2024 23:54:20 +0800
Subject: [PATCH] add enable_low_mem

---
 iopaint/api.py                           |  1 +
 iopaint/cli.py                           |  2 ++
 iopaint/model/controlnet.py              |  5 ++---
 iopaint/model/power_paint/power_paint.py |  4 +++-
 iopaint/model/utils.py                   | 15 +++++++++++++++
 iopaint/schema.py                        |  1 +
 iopaint/tests/test_outpainting.py        |  1 +
 7 files changed, 25 insertions(+), 4 deletions(-)

diff --git a/iopaint/api.py b/iopaint/api.py
index e8227e0..eff9ad7 100644
--- a/iopaint/api.py
+++ b/iopaint/api.py
@@ -356,6 +356,7 @@ class Api:
             name=self.config.model,
             device=torch.device(self.config.device),
             no_half=self.config.no_half,
+            low_mem=self.config.low_mem,
             disable_nsfw=self.config.disable_nsfw_checker,
             sd_cpu_textencoder=self.config.cpu_textencoder,
             cpu_offload=self.config.cpu_offload,
diff --git a/iopaint/cli.py b/iopaint/cli.py
index 6625a21..8cd7eb9 100644
--- a/iopaint/cli.py
+++ b/iopaint/cli.py
@@ -106,6 +106,7 @@ def start(
         file_okay=False,
         callback=setup_model_dir,
     ),
+    low_mem: bool = Option(False, help="Enable attention slicing and vae tiling to save memory."),
     no_half: bool = Option(False, help=NO_HALF_HELP),
     cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP),
     disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP),
@@ -170,6 +171,7 @@ def start(
         port=port,
         model=model,
         no_half=no_half,
+        low_mem=low_mem,
         cpu_offload=cpu_offload,
         disable_nsfw_checker=disable_nsfw_checker,
         cpu_textencoder=cpu_textencoder,
diff --git a/iopaint/model/controlnet.py b/iopaint/model/controlnet.py
index ffaef1d..0360b67 100644
--- a/iopaint/model/controlnet.py
+++ b/iopaint/model/controlnet.py
@@ -13,7 +13,7 @@ from .helper.controlnet_preprocess import (
     make_inpaint_control_image,
 )
 from .helper.cpu_text_encoder import CPUTextEncoderWrapper
-from .utils import get_scheduler, handle_from_pretrained_exceptions, get_torch_dtype
+from .utils import get_scheduler, handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem
 
 
 class ControlNet(DiffusionInpaintModel):
@@ -94,8 +94,7 @@ class ControlNet(DiffusionInpaintModel):
             **model_kwargs,
         )
 
-        if torch.backends.mps.is_available():
-            self.model.enable_attention_slicing()
+        enable_low_mem(self.model, kwargs.get("low_mem", False))
 
         if kwargs.get("cpu_offload", False) and use_gpu:
             logger.info("Enable sequential cpu offload")
diff --git a/iopaint/model/power_paint/power_paint.py b/iopaint/model/power_paint/power_paint.py
index 616bb35..18fde76 100644
--- a/iopaint/model/power_paint/power_paint.py
+++ b/iopaint/model/power_paint/power_paint.py
@@ -6,7 +6,7 @@ from loguru import logger
 
 from ..base import DiffusionInpaintModel
 from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
-from ..utils import handle_from_pretrained_exceptions, get_torch_dtype
+from ..utils import handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem
 from iopaint.schema import InpaintRequest
 from .powerpaint_tokenizer import add_task_to_prompt
 from ...const import POWERPAINT_NAME
@@ -43,6 +43,8 @@ class PowerPaint(DiffusionInpaintModel):
         )
         self.model.tokenizer = PowerPaintTokenizer(self.model.tokenizer)
 
+        enable_low_mem(self.model, kwargs.get("low_mem", False))
+
         if kwargs.get("cpu_offload", False) and use_gpu:
             logger.info("Enable sequential cpu offload")
             self.model.enable_sequential_cpu_offload(gpu_id=0)
diff --git a/iopaint/model/utils.py b/iopaint/model/utils.py
index 80e88b3..80ecbf9 100644
--- a/iopaint/model/utils.py
+++ b/iopaint/model/utils.py
@@ -1002,3 +1002,18 @@ def get_torch_dtype(device, no_half: bool):
     if device in ["cuda", "mps"] and use_fp16:
         return use_gpu, torch.float16
     return use_gpu, torch.float32
+
+
+def enable_low_mem(pipe, enable: bool):
+    if torch.backends.mps.is_available():
+        # https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
+        # CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers.
+        if enable:
+            pipe.enable_attention_slicing("max")
+        else:
+            # https://huggingface.co/docs/diffusers/optimization/mps
+            # Devices with less than 64GB of memory are recommended to use enable_attention_slicing
+            pipe.enable_attention_slicing()
+
+    if enable:
+        pipe.vae.enable_tiling()
diff --git a/iopaint/schema.py b/iopaint/schema.py
index c3a5588..8ee11b7 100644
--- a/iopaint/schema.py
+++ b/iopaint/schema.py
@@ -86,6 +86,7 @@ class ApiConfig(BaseModel):
     port: int
     model: str
     no_half: bool
+    low_mem: bool
     cpu_offload: bool
     disable_nsfw_checker: bool
     cpu_textencoder: bool
diff --git a/iopaint/tests/test_outpainting.py b/iopaint/tests/test_outpainting.py
index a84a1ec..024d701 100644
--- a/iopaint/tests/test_outpainting.py
+++ b/iopaint/tests/test_outpainting.py
@@ -114,6 +114,7 @@ def test_powerpaint_outpainting(name, device, rect):
         device=torch.device(device),
         disable_nsfw=True,
         sd_cpu_textencoder=False,
+        low_mem=True
     )
     cfg = get_config(
         prompt="a dog sitting on a bench in the park",