add back enable_attention_slicing for mps device
parent 3b40671e33
commit 5da47ee035
@@ -96,6 +96,9 @@ class ControlNet(DiffusionInpaintModel):
             **model_kwargs,
         )

+        if torch.backends.mps.is_available():
+            self.model.enable_attention_slicing()
+
         if kwargs.get("cpu_offload", False) and use_gpu:
             logger.info("Enable sequential cpu offload")
             self.model.enable_sequential_cpu_offload(gpu_id=0)
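For reference, a minimal standalone sketch of the guard this commit adds back, assuming a generic diffusers inpainting pipeline (the checkpoint id below is illustrative, not taken from this commit):

    import torch
    from diffusers import AutoPipelineForInpainting

    # Illustrative checkpoint; any diffusers inpainting model works here.
    pipe = AutoPipelineForInpainting.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting"
    )

    # On Apple Silicon (MPS), compute attention in slices to cut peak memory.
    # Not recommended on CUDA when PyTorch 2.0 SDPA or xFormers is in use.
    if torch.backends.mps.is_available():
        pipe.enable_attention_slicing()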
@@ -34,6 +34,8 @@ class InstructPix2Pix(DiffusionInpaintModel):
         self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
             self.name, variant="fp16", torch_dtype=torch_dtype, **model_kwargs
         )
+        if torch.backends.mps.is_available():
+            self.model.enable_attention_slicing()

         if kwargs.get("cpu_offload", False) and use_gpu:
             logger.info("Enable sequential cpu offload")
@@ -26,6 +26,8 @@ class Kandinsky(DiffusionInpaintModel):
         self.model = AutoPipelineForInpainting.from_pretrained(
             self.name, **model_kwargs
         ).to(device)
+        if torch.backends.mps.is_available():
+            self.model.enable_attention_slicing()

         self.callback = kwargs.pop("callback", None)

@@ -32,6 +32,9 @@ class PaintByExample(DiffusionInpaintModel):
             self.name, torch_dtype=torch_dtype, **model_kwargs
         )

+        if torch.backends.mps.is_available():
+            self.model.enable_attention_slicing()
+
         # TODO: gpu_id
         if kwargs.get("cpu_offload", False) and use_gpu:
             self.model.image_encoder = self.model.image_encoder.to(device)
@@ -29,7 +29,6 @@ class SD(DiffusionInpaintModel):
                 requires_safety_checker=False,
             )
         )

         use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
         torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
@@ -51,6 +50,11 @@ class SD(DiffusionInpaintModel):
             **model_kwargs,
         )

+        if torch.backends.mps.is_available():
+            # MPS: Recommended RAM < 64 GB https://huggingface.co/docs/diffusers/optimization/mps
+            # CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers. https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
+            self.model.enable_attention_slicing()
+
         if kwargs.get("cpu_offload", False) and use_gpu:
             logger.info("Enable sequential cpu offload")
             self.model.enable_sequential_cpu_offload(gpu_id=0)
@@ -50,6 +50,11 @@ class SDXL(DiffusionInpaintModel):
             variant="fp16",
         )

+        if torch.backends.mps.is_available():
+            # MPS: Recommended RAM < 64 GB https://huggingface.co/docs/diffusers/optimization/mps
+            # CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers. https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
+            self.model.enable_attention_slicing()
+
         if kwargs.get("cpu_offload", False) and use_gpu:
             logger.info("Enable sequential cpu offload")
             self.model.enable_sequential_cpu_offload(gpu_id=0)
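The cpu_offload branch kept next to the new guard in SD and SDXL combines the two memory savers; below is a sketch of that combined logic as a standalone helper (the function name is hypothetical, the APIs are the ones used in the diff):

    import torch
    from loguru import logger

    def apply_memory_optimizations(pipe, device: torch.device, cpu_offload: bool):
        # Hypothetical helper mirroring the branches in the SD/SDXL constructors.
        use_gpu = device == torch.device("cuda") and torch.cuda.is_available()

        # MPS: slice the attention computation to lower peak RAM usage.
        if torch.backends.mps.is_available():
            pipe.enable_attention_slicing()

        # CUDA: stream submodules to the GPU one at a time during inference.
        # Much slower, but keeps VRAM usage close to the largest submodule.
        if cpu_offload and use_gpu:
            logger.info("Enable sequential cpu offload")
            pipe.enable_sequential_cpu_offload(gpu_id=0)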
@@ -24,7 +24,6 @@ from diffusers import (
     KDPM2AncestralDiscreteScheduler,
     HeunDiscreteScheduler,
 )
 from diffusers.configuration_utils import FrozenDict
 from loguru import logger

 from iopaint.schema import SDSampler