diff --git a/iopaint/model/controlnet.py b/iopaint/model/controlnet.py
index 0360b67..d3894a6 100644
--- a/iopaint/model/controlnet.py
+++ b/iopaint/model/controlnet.py
@@ -13,7 +13,12 @@ from .helper.controlnet_preprocess import (
     make_inpaint_control_image,
 )
 from .helper.cpu_text_encoder import CPUTextEncoderWrapper
-from .utils import get_scheduler, handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem
+from .utils import (
+    get_scheduler,
+    handle_from_pretrained_exceptions,
+    get_torch_dtype,
+    enable_low_mem,
+)
 
 
 class ControlNet(DiffusionInpaintModel):
diff --git a/iopaint/model/instruct_pix2pix.py b/iopaint/model/instruct_pix2pix.py
index a792f0a..44c064a 100644
--- a/iopaint/model/instruct_pix2pix.py
+++ b/iopaint/model/instruct_pix2pix.py
@@ -6,7 +6,7 @@ from loguru import logger
 from iopaint.const import INSTRUCT_PIX2PIX_NAME
 from .base import DiffusionInpaintModel
 from iopaint.schema import InpaintRequest
-from .utils import get_torch_dtype
+from .utils import get_torch_dtype, enable_low_mem
 
 
 class InstructPix2Pix(DiffusionInpaintModel):
@@ -33,8 +33,7 @@ class InstructPix2Pix(DiffusionInpaintModel):
         self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
             self.name, variant="fp16", torch_dtype=torch_dtype, **model_kwargs
         )
-        if torch.backends.mps.is_available():
-            self.model.enable_attention_slicing()
+        enable_low_mem(self.model, kwargs.get("low_mem", False))
 
         if kwargs.get("cpu_offload", False) and use_gpu:
             logger.info("Enable sequential cpu offload")
diff --git a/iopaint/model/kandinsky.py b/iopaint/model/kandinsky.py
index fe67891..89dea09 100644
--- a/iopaint/model/kandinsky.py
+++ b/iopaint/model/kandinsky.py
@@ -6,7 +6,7 @@ import torch
 from iopaint.const import KANDINSKY22_NAME
 from .base import DiffusionInpaintModel
 from iopaint.schema import InpaintRequest
-from .utils import get_torch_dtype
+from .utils import get_torch_dtype, enable_low_mem
 
 
 class Kandinsky(DiffusionInpaintModel):
@@ -25,8 +25,7 @@
         self.model = AutoPipelineForInpainting.from_pretrained(
             self.name, **model_kwargs
         ).to(device)
-        if torch.backends.mps.is_available():
-            self.model.enable_attention_slicing()
+        enable_low_mem(self.model, kwargs.get("low_mem", False))
 
         self.callback = kwargs.pop("callback", None)
 
diff --git a/iopaint/model/paint_by_example.py b/iopaint/model/paint_by_example.py
index 2ae4cad..fd43263 100644
--- a/iopaint/model/paint_by_example.py
+++ b/iopaint/model/paint_by_example.py
@@ -7,7 +7,7 @@ from loguru import logger
 from iopaint.helper import decode_base64_to_image
 from .base import DiffusionInpaintModel
 from iopaint.schema import InpaintRequest
-from .utils import get_torch_dtype
+from .utils import get_torch_dtype, enable_low_mem
 
 
 class PaintByExample(DiffusionInpaintModel):
@@ -30,9 +30,7 @@
         self.model = DiffusionPipeline.from_pretrained(
             self.name, torch_dtype=torch_dtype, **model_kwargs
         )
-
-        if torch.backends.mps.is_available():
-            self.model.enable_attention_slicing()
+        enable_low_mem(self.model, kwargs.get("low_mem", False))
 
         # TODO: gpu_id
         if kwargs.get("cpu_offload", False) and use_gpu:
diff --git a/iopaint/model/sd.py b/iopaint/model/sd.py
index 95189f4..eb1a462 100644
--- a/iopaint/model/sd.py
+++ b/iopaint/model/sd.py
@@ -5,7 +5,7 @@ from loguru import logger
 
 from .base import DiffusionInpaintModel
 from .helper.cpu_text_encoder import CPUTextEncoderWrapper
-from .utils import handle_from_pretrained_exceptions, get_torch_dtype
+from .utils import handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem
 from iopaint.schema import InpaintRequest, ModelType
 
 
@@ -48,10 +48,7 @@ class SD(DiffusionInpaintModel):
             **model_kwargs,
         )
 
-        if torch.backends.mps.is_available():
-            # MPS: Recommended RAM < 64 GB https://huggingface.co/docs/diffusers/optimization/mps
-            # CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers. https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
-            self.model.enable_attention_slicing()
+        enable_low_mem(self.model, kwargs.get("low_mem", False))
 
         if kwargs.get("cpu_offload", False) and use_gpu:
             logger.info("Enable sequential cpu offload")
diff --git a/iopaint/model/sdxl.py b/iopaint/model/sdxl.py
index cf1dab3..f54a721 100644
--- a/iopaint/model/sdxl.py
+++ b/iopaint/model/sdxl.py
@@ -9,7 +9,7 @@ from loguru import logger
 from iopaint.schema import InpaintRequest, ModelType
 
 from .base import DiffusionInpaintModel
-from .utils import handle_from_pretrained_exceptions, get_torch_dtype
+from .utils import handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem
 
 
 class SDXL(DiffusionInpaintModel):
@@ -47,10 +47,7 @@ class SDXL(DiffusionInpaintModel):
             variant="fp16",
         )
 
-        if torch.backends.mps.is_available():
-            # MPS: Recommended RAM < 64 GB https://huggingface.co/docs/diffusers/optimization/mps
-            # CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers. https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
-            self.model.enable_attention_slicing()
+        enable_low_mem(self.model, kwargs.get("low_mem", False))
 
         if kwargs.get("cpu_offload", False) and use_gpu:
             logger.info("Enable sequential cpu offload")
diff --git a/iopaint/tests/test_low_mem.py b/iopaint/tests/test_low_mem.py
new file mode 100644
index 0000000..70e8801
--- /dev/null
+++ b/iopaint/tests/test_low_mem.py
@@ -0,0 +1,131 @@
+import os
+
+from loguru import logger
+
+from iopaint.tests.utils import check_device, get_config, assert_equal, current_dir
+
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+import pytest
+import torch
+
+from iopaint.model_manager import ModelManager
+from iopaint.schema import HDStrategy, SDSampler, FREEUConfig
+
+
+@pytest.mark.parametrize("device", ["cuda", "mps"])
+def test_runway_sd_1_5_low_mem(device):
+    sd_steps = check_device(device)
+    model = ModelManager(
+        name="runwayml/stable-diffusion-inpainting",
+        device=torch.device(device),
+        disable_nsfw=True,
+        sd_cpu_textencoder=False,
+        low_mem=True,
+    )
+
+    all_samplers = [member.value for member in SDSampler.__members__.values()]
+    print(all_samplers)
+    cfg = get_config(
+        strategy=HDStrategy.ORIGINAL,
+        prompt="a fox sitting on a bench",
+        sd_steps=sd_steps,
+        sd_sampler=SDSampler.ddim,
+    )
+
+    name = f"device_{device}"
+
+    assert_equal(
+        model,
+        cfg,
+        f"runway_sd_{name}_low_mem.png",
+        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
+        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
+    )
+
+
+@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
+@pytest.mark.parametrize("sampler", [SDSampler.lcm])
+def test_runway_sd_lcm_lora_low_mem(device, sampler):
+    check_device(device)
+
+    sd_steps = 5
+    model = ModelManager(
+        name="runwayml/stable-diffusion-inpainting",
+        device=torch.device(device),
+        disable_nsfw=True,
+        sd_cpu_textencoder=False,
+        low_mem=True,
+    )
+    cfg = get_config(
+        strategy=HDStrategy.ORIGINAL,
+        prompt="face of a fox, sitting on a bench",
+        sd_steps=sd_steps,
+        sd_guidance_scale=2,
+        sd_lcm_lora=True,
+    )
+    cfg.sd_sampler = sampler
+
+    assert_equal(
+        model,
+        cfg,
+        f"runway_sd_1_5_lcm_lora_device_{device}_low_mem.png",
+        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
+        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
+    )
+
+
+@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
+@pytest.mark.parametrize("sampler", [SDSampler.ddim])
+def test_runway_sd_freeu(device, sampler):
+    sd_steps = check_device(device)
+    model = ModelManager(
+        name="runwayml/stable-diffusion-inpainting",
+        device=torch.device(device),
+        disable_nsfw=True,
+        sd_cpu_textencoder=False,
+        low_mem=True,
+    )
+    cfg = get_config(
+        strategy=HDStrategy.ORIGINAL,
+        prompt="face of a fox, sitting on a bench",
+        sd_steps=sd_steps,
+        sd_guidance_scale=7.5,
+        sd_freeu=True,
+        sd_freeu_config=FREEUConfig(),
+    )
+    cfg.sd_sampler = sampler
+
+    assert_equal(
+        model,
+        cfg,
+        f"runway_sd_1_5_freeu_device_{device}_low_mem.png",
+        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
+        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
+    )
+
+
+@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
+@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
+@pytest.mark.parametrize("sampler", [SDSampler.ddim])
+def test_runway_norm_sd_model(device, strategy, sampler):
+    sd_steps = check_device(device)
+    model = ModelManager(
+        name="runwayml/stable-diffusion-v1-5",
+        device=torch.device(device),
+        disable_nsfw=True,
+        sd_cpu_textencoder=False,
+        low_mem=True,
+    )
+    cfg = get_config(
+        strategy=strategy, prompt="face of a fox, sitting on a bench", sd_steps=sd_steps
+    )
+    cfg.sd_sampler = sampler
+
+    assert_equal(
+        model,
+        cfg,
+        f"runway_{device}_norm_sd_model_device_{device}_low_mem.png",
+        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
+        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
+    )
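
Note: every call site above hands the pipeline and a low_mem flag to enable_low_mem() from iopaint/model/utils.py, but the helper's body is not part of this diff. The following is only a plausible sketch of what it might do, inferred from the code it replaces (unconditional attention slicing on MPS, with the removed comments about MPS RAM < 64 GB and SDPA) plus the new opt-in flag; the parameter names pipe and enable are illustrative, and the actual implementation in utils.py may differ:

    import torch

    def enable_low_mem(pipe, enable: bool):
        # Preserve the old default: slice attention on Apple Silicon.
        # MPS: recommended when the machine has less than 64 GB of RAM.
        # https://huggingface.co/docs/diffusers/optimization/mps
        if torch.backends.mps.is_available():
            pipe.enable_attention_slicing()

        if enable:
            # Opt-in low-memory mode: trade inference speed for a smaller
            # peak-memory footprint.
            pipe.enable_attention_slicing()
            # Decode latents tile by tile rather than in one pass; guard the
            # call since not every pipeline here exposes a tiling-capable VAE.
            if hasattr(pipe, "vae") and hasattr(pipe.vae, "enable_tiling"):
                pipe.vae.enable_tiling()

Both enable_attention_slicing() and AutoencoderKL.enable_tiling() are standard diffusers methods, which would let a helper like this stay model-agnostic across the SD, SDXL, Kandinsky, InstructPix2Pix, and PaintByExample call sites that the diff touches.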