add enable_low_mem
This commit is contained in:
parent
a71c3fbe1b
commit
a49c3f86d3
@ -356,6 +356,7 @@ class Api:
|
|||||||
name=self.config.model,
|
name=self.config.model,
|
||||||
device=torch.device(self.config.device),
|
device=torch.device(self.config.device),
|
||||||
no_half=self.config.no_half,
|
no_half=self.config.no_half,
|
||||||
|
low_mem=self.config.low_mem,
|
||||||
disable_nsfw=self.config.disable_nsfw_checker,
|
disable_nsfw=self.config.disable_nsfw_checker,
|
||||||
sd_cpu_textencoder=self.config.cpu_textencoder,
|
sd_cpu_textencoder=self.config.cpu_textencoder,
|
||||||
cpu_offload=self.config.cpu_offload,
|
cpu_offload=self.config.cpu_offload,
|
||||||
|
@ -106,6 +106,7 @@ def start(
|
|||||||
file_okay=False,
|
file_okay=False,
|
||||||
callback=setup_model_dir,
|
callback=setup_model_dir,
|
||||||
),
|
),
|
||||||
|
low_mem: bool = Option(False, help="Enable attention slicing and vae tiling to save memory."),
|
||||||
no_half: bool = Option(False, help=NO_HALF_HELP),
|
no_half: bool = Option(False, help=NO_HALF_HELP),
|
||||||
cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP),
|
cpu_offload: bool = Option(False, help=CPU_OFFLOAD_HELP),
|
||||||
disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP),
|
disable_nsfw_checker: bool = Option(False, help=DISABLE_NSFW_HELP),
|
||||||
@ -170,6 +171,7 @@ def start(
|
|||||||
port=port,
|
port=port,
|
||||||
model=model,
|
model=model,
|
||||||
no_half=no_half,
|
no_half=no_half,
|
||||||
|
low_mem=low_mem,
|
||||||
cpu_offload=cpu_offload,
|
cpu_offload=cpu_offload,
|
||||||
disable_nsfw_checker=disable_nsfw_checker,
|
disable_nsfw_checker=disable_nsfw_checker,
|
||||||
cpu_textencoder=cpu_textencoder,
|
cpu_textencoder=cpu_textencoder,
|
||||||
|
@ -13,7 +13,7 @@ from .helper.controlnet_preprocess import (
|
|||||||
make_inpaint_control_image,
|
make_inpaint_control_image,
|
||||||
)
|
)
|
||||||
from .helper.cpu_text_encoder import CPUTextEncoderWrapper
|
from .helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||||
from .utils import get_scheduler, handle_from_pretrained_exceptions, get_torch_dtype
|
from .utils import get_scheduler, handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem
|
||||||
|
|
||||||
|
|
||||||
class ControlNet(DiffusionInpaintModel):
|
class ControlNet(DiffusionInpaintModel):
|
||||||
@ -94,8 +94,7 @@ class ControlNet(DiffusionInpaintModel):
|
|||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if torch.backends.mps.is_available():
|
enable_low_mem(self.model, kwargs.get("low_mem", False))
|
||||||
self.model.enable_attention_slicing()
|
|
||||||
|
|
||||||
if kwargs.get("cpu_offload", False) and use_gpu:
|
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||||
logger.info("Enable sequential cpu offload")
|
logger.info("Enable sequential cpu offload")
|
||||||
|
@ -6,7 +6,7 @@ from loguru import logger
|
|||||||
|
|
||||||
from ..base import DiffusionInpaintModel
|
from ..base import DiffusionInpaintModel
|
||||||
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
|
from ..helper.cpu_text_encoder import CPUTextEncoderWrapper
|
||||||
from ..utils import handle_from_pretrained_exceptions, get_torch_dtype
|
from ..utils import handle_from_pretrained_exceptions, get_torch_dtype, enable_low_mem
|
||||||
from iopaint.schema import InpaintRequest
|
from iopaint.schema import InpaintRequest
|
||||||
from .powerpaint_tokenizer import add_task_to_prompt
|
from .powerpaint_tokenizer import add_task_to_prompt
|
||||||
from ...const import POWERPAINT_NAME
|
from ...const import POWERPAINT_NAME
|
||||||
@ -43,6 +43,8 @@ class PowerPaint(DiffusionInpaintModel):
|
|||||||
)
|
)
|
||||||
self.model.tokenizer = PowerPaintTokenizer(self.model.tokenizer)
|
self.model.tokenizer = PowerPaintTokenizer(self.model.tokenizer)
|
||||||
|
|
||||||
|
enable_low_mem(self.model, kwargs.get("low_mem", False))
|
||||||
|
|
||||||
if kwargs.get("cpu_offload", False) and use_gpu:
|
if kwargs.get("cpu_offload", False) and use_gpu:
|
||||||
logger.info("Enable sequential cpu offload")
|
logger.info("Enable sequential cpu offload")
|
||||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||||
|
@ -1002,3 +1002,18 @@ def get_torch_dtype(device, no_half: bool):
|
|||||||
if device in ["cuda", "mps"] and use_fp16:
|
if device in ["cuda", "mps"] and use_fp16:
|
||||||
return use_gpu, torch.float16
|
return use_gpu, torch.float16
|
||||||
return use_gpu, torch.float32
|
return use_gpu, torch.float32
|
||||||
|
|
||||||
|
|
||||||
|
def enable_low_mem(pipe, enable: bool):
|
||||||
|
if torch.backends.mps.is_available():
|
||||||
|
# https://huggingface.co/docs/diffusers/v0.25.0/en/api/pipelines/stable_diffusion/image_variation#diffusers.StableDiffusionImageVariationPipeline.enable_attention_slicing
|
||||||
|
# CUDA: Don't enable attention slicing if you're already using `scaled_dot_product_attention` (SDPA) from PyTorch 2.0 or xFormers.
|
||||||
|
if enable:
|
||||||
|
pipe.enable_attention_slicing("max")
|
||||||
|
else:
|
||||||
|
# https://huggingface.co/docs/diffusers/optimization/mps
|
||||||
|
# Devices with less than 64GB of memory are recommended to use enable_attention_slicing
|
||||||
|
pipe.enable_attention_slicing()
|
||||||
|
|
||||||
|
if enable:
|
||||||
|
pipe.vae.enable_tiling()
|
||||||
|
@ -86,6 +86,7 @@ class ApiConfig(BaseModel):
|
|||||||
port: int
|
port: int
|
||||||
model: str
|
model: str
|
||||||
no_half: bool
|
no_half: bool
|
||||||
|
low_mem: bool
|
||||||
cpu_offload: bool
|
cpu_offload: bool
|
||||||
disable_nsfw_checker: bool
|
disable_nsfw_checker: bool
|
||||||
cpu_textencoder: bool
|
cpu_textencoder: bool
|
||||||
|
@ -114,6 +114,7 @@ def test_powerpaint_outpainting(name, device, rect):
|
|||||||
device=torch.device(device),
|
device=torch.device(device),
|
||||||
disable_nsfw=True,
|
disable_nsfw=True,
|
||||||
sd_cpu_textencoder=False,
|
sd_cpu_textencoder=False,
|
||||||
|
low_mem=True
|
||||||
)
|
)
|
||||||
cfg = get_config(
|
cfg = get_config(
|
||||||
prompt="a dog sitting on a bench in the park",
|
prompt="a dog sitting on a bench in the park",
|
||||||
|
Loading…
Reference in New Issue
Block a user