optimize sd/paint_by_example model VRAM usage

This commit is contained in:
Qing 2023-01-18 18:34:10 +08:00
parent 384f16dcd0
commit 148e97e8da
7 changed files with 61 additions and 24 deletions

View File

@ -102,7 +102,7 @@ if __name__ == "__main__":
name=args.name, name=args.name,
device=device, device=device,
sd_run_local=True, sd_run_local=True,
sd_disable_nsfw=True, disable_nsfw=True,
sd_cpu_textencoder=True, sd_cpu_textencoder=True,
hf_access_token="123" hf_access_token="123"
) )

View File

@ -22,16 +22,30 @@ class PaintByExample(InpaintModel):
use_gpu = device == torch.device('cuda') and torch.cuda.is_available() use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)} model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
logger.info("Disable Paint By Example Model NSFW checker")
model_kwargs.update(dict(
safety_checker=None,
requires_safety_checker=False
))
self.model = DiffusionPipeline.from_pretrained( self.model = DiffusionPipeline.from_pretrained(
"Fantasy-Studio/Paint-by-Example", "Fantasy-Studio/Paint-by-Example",
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
**model_kwargs **model_kwargs
) )
self.model = self.model.to(device)
self.model.enable_attention_slicing() self.model.enable_attention_slicing()
if kwargs.get('enable_xformers', False):
self.model.enable_xformers_memory_efficient_attention()
# TODO: gpu_id # TODO: gpu_id
if kwargs.get('cpu_offload', False) and torch.cuda.is_available(): if kwargs.get('cpu_offload', False) and use_gpu:
self.model.image_encoder = self.model.image_encoder.to(device)
self.model.enable_sequential_cpu_offload(gpu_id=0) self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
self.model = self.model.to(device)
def forward(self, image, mask, config: Config): def forward(self, image, mask, config: Config):
"""Input image and output image have same size """Input image and output image have same size

View File

@ -37,10 +37,12 @@ class SD(InpaintModel):
fp16 = not kwargs.get('no_half', False) fp16 = not kwargs.get('no_half', False)
model_kwargs = {"local_files_only": kwargs.get('local_files_only', kwargs['sd_run_local'])} model_kwargs = {"local_files_only": kwargs.get('local_files_only', kwargs['sd_run_local'])}
if kwargs['sd_disable_nsfw']: if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
logger.info("Disable Stable Diffusion Model NSFW checker") logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(dict( model_kwargs.update(dict(
safety_checker=None, safety_checker=None,
feature_extractor=None,
requires_safety_checker=False
)) ))
use_gpu = device == torch.device('cuda') and torch.cuda.is_available() use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
@ -52,19 +54,19 @@ class SD(InpaintModel):
use_auth_token=kwargs["hf_access_token"], use_auth_token=kwargs["hf_access_token"],
**model_kwargs **model_kwargs
) )
self.model = self.model.to(device)
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
self.model.enable_attention_slicing() self.model.enable_attention_slicing()
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention # https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
if kwargs.get('sd_enable_xformers', False): if kwargs.get('enable_xformers', False):
self.model.enable_xformers_memory_efficient_attention() self.model.enable_xformers_memory_efficient_attention()
if kwargs.get('cpu_offload', False) and torch.cuda.is_available(): if kwargs.get('cpu_offload', False) and use_gpu:
# TODO: gpu_id # TODO: gpu_id
logger.info("Enable sequential cpu offload") logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0) self.model.enable_sequential_cpu_offload(gpu_id=0)
else: else:
self.model = self.model.to(device)
if kwargs['sd_cpu_textencoder']: if kwargs['sd_cpu_textencoder']:
logger.info("Run Stable Diffusion TextEncoder on CPU") logger.info("Run Stable Diffusion TextEncoder on CPU")
self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder, torch_dtype) self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder, torch_dtype)

View File

@ -28,6 +28,11 @@ def parse_args():
action="store_true", action="store_true",
help="Disable Stable Diffusion NSFW checker", help="Disable Stable Diffusion NSFW checker",
) )
parser.add_argument(
"--disable-nsfw",
action="store_true",
help="Disable Stable Diffusion/Paint By Example NSFW checker",
)
parser.add_argument( parser.add_argument(
"--sd-cpu-textencoder", "--sd-cpu-textencoder",
action="store_true", action="store_true",
@ -48,6 +53,11 @@ def parse_args():
action="store_true", action="store_true",
help="Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers" help="Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers"
) )
parser.add_argument(
"--enable-xformers",
action="store_true",
help="sd/paint_by_example model. Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers"
)
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu", "mps"]) parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu", "mps"])
parser.add_argument("--gui", action="store_true", help="Launch as desktop app") parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
parser.add_argument( parser.add_argument(
@ -78,7 +88,8 @@ def parse_args():
if args.device == "cuda": if args.device == "cuda":
import torch import torch
if torch.cuda.is_available() is False: if torch.cuda.is_available() is False:
parser.error("torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation") parser.error(
"torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation")
if args.model_dir is not None: if args.model_dir is not None:
if os.path.isfile(args.model_dir): if os.path.isfile(args.model_dir):
@ -90,7 +101,6 @@ def parse_args():
os.environ["XDG_CACHE_HOME"] = args.model_dir os.environ["XDG_CACHE_HOME"] = args.model_dir
if args.input is not None: if args.input is not None:
if not os.path.exists(args.input): if not os.path.exists(args.input):
parser.error(f"invalid --input: {args.input} not exists") parser.error(f"invalid --input: {args.input} not exists")

View File

@ -392,12 +392,12 @@ def main(args):
device=device, device=device,
no_half=args.no_half, no_half=args.no_half,
hf_access_token=args.hf_access_token, hf_access_token=args.hf_access_token,
sd_disable_nsfw=args.sd_disable_nsfw, disable_nsfw=args.sd_disable_nsfw or args.disable_nsfw,
sd_cpu_textencoder=args.sd_cpu_textencoder, sd_cpu_textencoder=args.sd_cpu_textencoder,
sd_run_local=args.sd_run_local, sd_run_local=args.sd_run_local,
local_files_only=args.local_files_only, local_files_only=args.local_files_only,
cpu_offload=args.cpu_offload, cpu_offload=args.cpu_offload,
sd_enable_xformers=args.sd_enable_xformers, enable_xformers=args.sd_enable_xformers or args.enable_xformers,
callback=diffuser_callback, callback=diffuser_callback,
) )

View File

@ -37,7 +37,7 @@ def assert_equal(
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_paint_by_example(strategy): def test_paint_by_example(strategy):
model = ModelManager(name="paint_by_example", device=device) model = ModelManager(name="paint_by_example", device=device, disable_nsfw=True)
cfg = get_config(strategy, paint_by_example_steps=30) cfg = get_config(strategy, paint_by_example_steps=30)
assert_equal( assert_equal(
model, model,
@ -50,9 +50,22 @@ def test_paint_by_example(strategy):
) )
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_paint_by_example_disable_nsfw(strategy):
model = ModelManager(name="paint_by_example", device=device, disable_nsfw=False)
cfg = get_config(strategy, paint_by_example_steps=30)
assert_equal(
model,
cfg,
f"paint_by_example_{strategy.capitalize()}_disable_nsfw.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
)
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_paint_by_example_sd_scale(strategy): def test_paint_by_example_sd_scale(strategy):
model = ModelManager(name="paint_by_example", device=device) model = ModelManager(name="paint_by_example", device=device, disable_nsfw=True)
cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85) cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85)
assert_equal( assert_equal(
model, model,
@ -67,7 +80,7 @@ def test_paint_by_example_sd_scale(strategy):
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_paint_by_example_cpu_offload(strategy): def test_paint_by_example_cpu_offload(strategy):
model = ModelManager(name="paint_by_example", device=device, cpu_offload=True) model = ModelManager(name="paint_by_example", device=device, cpu_offload=True, disable_nsfw=False)
cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85) cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85)
assert_equal( assert_equal(
model, model,
@ -75,14 +88,12 @@ def test_paint_by_example_cpu_offload(strategy):
f"paint_by_example_{strategy.capitalize()}_cpu_offload.png", f"paint_by_example_{strategy.capitalize()}_cpu_offload.png",
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
fy=0.9,
fx=1.3
) )
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_paint_by_example_cpu_offload_cpu_device(strategy): def test_paint_by_example_cpu_offload_cpu_device(strategy):
model = ModelManager(name="paint_by_example", device = torch.device('cpu'), cpu_offload=True) model = ModelManager(name="paint_by_example", device=torch.device('cpu'), cpu_offload=True, disable_nsfw=True)
cfg = get_config(strategy, paint_by_example_steps=1, sd_scale=0.85) cfg = get_config(strategy, paint_by_example_steps=1, sd_scale=0.85)
assert_equal( assert_equal(
model, model,

View File

@ -14,7 +14,7 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device) device = torch.device(device)
@pytest.mark.parametrize("sd_device", ['cpu', 'cuda']) @pytest.mark.parametrize("sd_device", ['cuda'])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("sampler", [SDSampler.ddim]) @pytest.mark.parametrize("sampler", [SDSampler.ddim])
@pytest.mark.parametrize("cpu_textencoder", [True, False]) @pytest.mark.parametrize("cpu_textencoder", [True, False])
@ -31,7 +31,7 @@ def test_runway_sd_1_5_ddim(sd_device, strategy, sampler, cpu_textencoder, disab
device=torch.device(sd_device), device=torch.device(sd_device),
hf_access_token="", hf_access_token="",
sd_run_local=True, sd_run_local=True,
sd_disable_nsfw=disable_nsfw, disable_nsfw=disable_nsfw,
sd_cpu_textencoder=cpu_textencoder, sd_cpu_textencoder=cpu_textencoder,
callback=callback) callback=callback)
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps) cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps)
@ -66,7 +66,7 @@ def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_ns
device=torch.device(sd_device), device=torch.device(sd_device),
hf_access_token="", hf_access_token="",
sd_run_local=True, sd_run_local=True,
sd_disable_nsfw=disable_nsfw, disable_nsfw=disable_nsfw,
sd_cpu_textencoder=cpu_textencoder, sd_cpu_textencoder=cpu_textencoder,
callback=callback) callback=callback)
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps) cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps)
@ -99,7 +99,7 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
device=torch.device(sd_device), device=torch.device(sd_device),
hf_access_token="", hf_access_token="",
sd_run_local=True, sd_run_local=True,
sd_disable_nsfw=False, disable_nsfw=False,
sd_cpu_textencoder=False, sd_cpu_textencoder=False,
callback=callback) callback=callback)
cfg = get_config( cfg = get_config(
@ -137,7 +137,7 @@ def test_runway_sd_1_5_sd_scale(sd_device, strategy, sampler, cpu_textencoder, d
device=torch.device(sd_device), device=torch.device(sd_device),
hf_access_token="", hf_access_token="",
sd_run_local=True, sd_run_local=True,
sd_disable_nsfw=disable_nsfw, disable_nsfw=disable_nsfw,
sd_cpu_textencoder=cpu_textencoder) sd_cpu_textencoder=cpu_textencoder)
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85) cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85)
cfg.sd_sampler = sampler cfg.sd_sampler = sampler
@ -166,7 +166,7 @@ def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):
device=torch.device(sd_device), device=torch.device(sd_device),
hf_access_token="", hf_access_token="",
sd_run_local=True, sd_run_local=True,
sd_disable_nsfw=False, disable_nsfw=True,
sd_cpu_textencoder=False, sd_cpu_textencoder=False,
cpu_offload=True) cpu_offload=True)
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85) cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85)
@ -191,7 +191,7 @@ def test_runway_sd_1_5_cpu_offload_cpu_device(sd_device, strategy, sampler):
device=torch.device(sd_device), device=torch.device(sd_device),
hf_access_token="", hf_access_token="",
sd_run_local=True, sd_run_local=True,
sd_disable_nsfw=False, disable_nsfw=False,
sd_cpu_textencoder=False, sd_cpu_textencoder=False,
cpu_offload=True) cpu_offload=True)
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=1, sd_scale=0.85) cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=1, sd_scale=0.85)