add back enable_attention_slicing for mps device

This commit is contained in:
Qing 2024-01-08 21:49:18 +08:00
parent 3b40671e33
commit 5da47ee035
7 changed files with 20 additions and 2 deletions

View File

@ -96,6 +96,9 @@ class ControlNet(DiffusionInpaintModel):
**model_kwargs,
)
if torch.backends.mps.is_available():
self.model.enable_attention_slicing()
if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)

View File

@ -34,6 +34,8 @@ class InstructPix2Pix(DiffusionInpaintModel):
self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
self.name, variant="fp16", torch_dtype=torch_dtype, **model_kwargs
)
if torch.backends.mps.is_available():
self.model.enable_attention_slicing()
if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload")

View File

@ -26,6 +26,8 @@ class Kandinsky(DiffusionInpaintModel):
self.model = AutoPipelineForInpainting.from_pretrained(
self.name, **model_kwargs
).to(device)
if torch.backends.mps.is_available():
self.model.enable_attention_slicing()
self.callback = kwargs.pop("callback", None)

View File

@ -32,6 +32,9 @@ class PaintByExample(DiffusionInpaintModel):
self.name, torch_dtype=torch_dtype, **model_kwargs
)
if torch.backends.mps.is_available():
self.model.enable_attention_slicing()
# TODO: gpu_id
if kwargs.get("cpu_offload", False) and use_gpu:
self.model.image_encoder = self.model.image_encoder.to(device)

View File

@ -29,7 +29,6 @@ class SD(DiffusionInpaintModel):
requires_safety_checker=False,
)
)
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
@ -51,6 +50,11 @@ class SD(DiffusionInpaintModel):
**model_kwargs,
)
if torch.backends.mps.is_available():
# MPS: Recommended RAM < 64 GB https://huggingface.co/docs/diffusers/optimization/mps
# CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers. https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
self.model.enable_attention_slicing()
if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)

View File

@ -50,6 +50,11 @@ class SDXL(DiffusionInpaintModel):
variant="fp16",
)
if torch.backends.mps.is_available():
# MPS: Recommended RAM < 64 GB https://huggingface.co/docs/diffusers/optimization/mps
# CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers. https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
self.model.enable_attention_slicing()
if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)

View File

@ -24,7 +24,6 @@ from diffusers import (
KDPM2AncestralDiscreteScheduler,
HeunDiscreteScheduler,
)
from diffusers.configuration_utils import FrozenDict
from loguru import logger
from iopaint.schema import SDSampler