add enable_low_mem

Qing 2024-01-08 23:54:20 +08:00
parent a71c3fbe1b
commit a49c3f86d3
7 changed files with 25 additions and 4 deletions

View File

@@ -356,6 +356,7 @@ class Api:
             name=self.config.model,
             device=torch.device(self.config.device),
             no_half=self.config.no_half,
+            low_mem=self.config.low_mem,
             disable_nsfw=self.config.disable_nsfw_checker,
             sd_cpu_textencoder=self.config.cpu_textencoder,
             cpu_offload=self.config.cpu_offload,

View File

@@ -106,6 +106,7 @@ def start(
         file_okay=False,
         callback=setup_model_dir,
     ),
+    low_mem: bool = Option(False, help="Enable attention slicing and vae tiling to save memory."),
    no_half: bool = Option(False, help=NO_HALF_HELP),
    cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP),
    disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP),
@@ -170,6 +171,7 @@ def start(
         port=port,
         model=model,
         no_half=no_half,
+        low_mem=low_mem,
         cpu_offload=cpu_offload,
         disable_nsfw_checker=disable_nsfw_checker,
         cpu_textencoder=cpu_textencoder,
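With Typer's default underscore-to-dash conversion, the new parameter should surface as a --low-mem flag on the start command, e.g. `iopaint start --model <model> --low-mem`. The `iopaint` command name is inferred from the package imports elsewhere in this commit, not shown here.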

View File

@@ -13,7 +13,7 @@ from .helper.controlnet_preprocess import (
     make_inpaint_control_image,
 )
 from .helper.cpu_text_encoder import CPUTextEncoderWrapper
-from .utils import get_scheduler, handle_from_pretrained_exceptions, get_torch_dtype
+from .utils import get_scheduler, handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem


 class ControlNet(DiffusionInpaintModel):
@@ -94,8 +94,7 @@ class ControlNet(DiffusionInpaintModel):
             **model_kwargs,
         )

-        if torch.backends.mps.is_available():
-            self.model.enable_attention_slicing()
+        enable_low_mem(self.model, kwargs.get("low_mem", False))

         if kwargs.get("cpu_offload", False) and use_gpu:
             logger.info("Enable sequential cpu offload")

View File

@ -6,7 +6,7 @@ from loguru import logger
from ..base import DiffusionInpaintModel from ..base import DiffusionInpaintModel
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
from ..utils import handle_from_pretrained_exceptions, get_torch_dtype from ..utils import handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem
from iopaint.schema import InpaintRequest from iopaint.schema import InpaintRequest
from .powerpaint_tokenizer import add_task_to_prompt from .powerpaint_tokenizer import add_task_to_prompt
from ...const import POWERPAINT_NAME from ...const import POWERPAINT_NAME
@ -43,6 +43,8 @@ class PowerPaint(DiffusionInpaintModel):
) )
self.model.tokenizer = PowerPaintTokenizer(self.model.tokenizer) self.model.tokenizer = PowerPaintTokenizer(self.model.tokenizer)
enable_low_mem(self.model, kwargs.get("low_mem", False))
if kwargs.get("cpu_offload", False) and use_gpu: if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload") logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0) self.model.enable_sequential_cpu_offload(gpu_id=0)

View File

@@ -1002,3 +1002,18 @@ def get_torch_dtype(device, no_half: bool):
     if device in ["cuda", "mps"] and use_fp16:
         return use_gpu, torch.float16
     return use_gpu, torch.float32
+
+
+def enable_low_mem(pipe, enable: bool):
+    if torch.backends.mps.is_available():
+        # https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
+        # CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers.
+        if enable:
+            pipe.enable_attention_slicing("max")
+        else:
+            # https://huggingface.co/docs/diffusers/optimization/mps
+            # Devices with less than 64GB of memory are recommended to use enable_attention_slicing
+            pipe.enable_attention_slicing()
+
+    if enable:
+        pipe.vae.enable_tiling()
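
For context, a minimal sketch of how the new helper composes with a diffusers pipeline. The pipeline class and model id below are illustrative assumptions, not part of this commit; only enable_low_mem itself is introduced above, and its module path is taken from the relative imports in the ControlNet and PowerPaint diffs.

import torch
from diffusers import StableDiffusionInpaintPipeline

# Helper added by this commit; the relative imports above place it in
# iopaint.model.utils.
from iopaint.model.utils import enable_low_mem

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",  # illustrative model id
    torch_dtype=torch.float16,
)

# enable=True: on MPS, attention is sliced aggressively ("max"); on other
# devices slicing is skipped (redundant with PyTorch 2.0 SDPA), and the
# VAE decodes tile-by-tile everywhere, trading speed for lower peak memory.
# enable=False: plain attention slicing is still turned on for MPS devices,
# per the diffusers MPS guidance.
enable_low_mem(pipe, enable=True)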

View File

@@ -86,6 +86,7 @@ class ApiConfig(BaseModel):
     port: int
     model: str
     no_half: bool
+    low_mem: bool
     cpu_offload: bool
     disable_nsfw_checker: bool
     cpu_textencoder: bool

View File

@@ -114,6 +114,7 @@ def test_powerpaint_outpainting(name, device, rect):
         device=torch.device(device),
         disable_nsfw=True,
         sd_cpu_textencoder=False,
+        low_mem=True
     )
     cfg = get_config(
         prompt="a dog sitting on a bench in the park",