add back enable_attention_slicing for mps device
commit 5da47ee035
parent 3b40671e33
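The commit applies one pattern to every pipeline wrapper: right after the diffusers pipeline is constructed, attention slicing is enabled only when PyTorch reports an MPS backend. A minimal sketch of that pattern outside the project, assuming a stock diffusers inpainting pipeline (the checkpoint id is only an illustration, not part of this commit):

import torch
from diffusers import StableDiffusionInpaintPipeline

# Illustrative checkpoint; any DiffusionPipeline exposes enable_attention_slicing().
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float32
)

if torch.backends.mps.is_available():
    # Attention slicing computes attention in chunks: a bit slower, but a much
    # lower peak memory footprint, which diffusers recommends on Apple Silicon
    # machines with less than 64 GB of RAM.
    pipe.enable_attention_slicing()
    pipe = pipe.to("mps")

Because the guard checks MPS availability, CUDA users are unaffected and keep PyTorch 2.0 scaled-dot-product attention, where slicing would only slow inference down.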
@@ -96,6 +96,9 @@ class ControlNet(DiffusionInpaintModel):
             **model_kwargs,
         )
 
+        if torch.backends.mps.is_available():
+            self.model.enable_attention_slicing()
+
         if kwargs.get("cpu_offload", False) and use_gpu:
             logger.info("Enable sequential cpu offload")
             self.model.enable_sequential_cpu_offload(gpu_id=0)
@@ -34,6 +34,8 @@ class InstructPix2Pix(DiffusionInpaintModel):
         self.model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
             self.name, variant="fp16", torch_dtype=torch_dtype, **model_kwargs
         )
+        if torch.backends.mps.is_available():
+            self.model.enable_attention_slicing()
 
         if kwargs.get("cpu_offload", False) and use_gpu:
             logger.info("Enable sequential cpu offload")
@@ -26,6 +26,8 @@ class Kandinsky(DiffusionInpaintModel):
         self.model = AutoPipelineForInpainting.from_pretrained(
             self.name, **model_kwargs
         ).to(device)
+        if torch.backends.mps.is_available():
+            self.model.enable_attention_slicing()
 
         self.callback = kwargs.pop("callback", None)
 
@@ -32,6 +32,9 @@ class PaintByExample(DiffusionInpaintModel):
             self.name, torch_dtype=torch_dtype, **model_kwargs
         )
 
+        if torch.backends.mps.is_available():
+            self.model.enable_attention_slicing()
+
         # TODO: gpu_id
         if kwargs.get("cpu_offload", False) and use_gpu:
             self.model.image_encoder = self.model.image_encoder.to(device)
@@ -29,7 +29,6 @@ class SD(DiffusionInpaintModel):
                     requires_safety_checker=False,
                 )
             )
 
         use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
         torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
-
@@ -51,6 +50,11 @@ class SD(DiffusionInpaintModel):
             **model_kwargs,
         )
 
+        if torch.backends.mps.is_available():
+            # MPS: Recommended RAM < 64 GB https://huggingface.co/docs/diffusers/optimization/mps
+            # CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers. https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
+            self.model.enable_attention_slicing()
+
         if kwargs.get("cpu_offload", False) and use_gpu:
             logger.info("Enable sequential cpu offload")
             self.model.enable_sequential_cpu_offload(gpu_id=0)
@@ -50,6 +50,11 @@ class SDXL(DiffusionInpaintModel):
             variant="fp16",
         )
 
+        if torch.backends.mps.is_available():
+            # MPS: Recommended RAM < 64 GB https://huggingface.co/docs/diffusers/optimization/mps
+            # CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers. https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
+            self.model.enable_attention_slicing()
+
         if kwargs.get("cpu_offload", False) and use_gpu:
             logger.info("Enable sequential cpu offload")
             self.model.enable_sequential_cpu_offload(gpu_id=0)
@@ -24,7 +24,6 @@ from diffusers import (
     KDPM2AncestralDiscreteScheduler,
     HeunDiscreteScheduler,
 )
-from diffusers.configuration_utils import FrozenDict
 from loguru import logger
 
 from iopaint.schema import SDSampler