optimize sd/paint_by_example modle VRAM usage
This commit is contained in:
parent
384f16dcd0
commit
148e97e8da
@ -102,7 +102,7 @@ if __name__ == "__main__":
|
||||
name=args.name,
|
||||
device=device,
|
||||
sd_run_local=True,
|
||||
sd_disable_nsfw=True,
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=True,
|
||||
hf_access_token="123"
|
||||
)
|
||||
|
@ -22,16 +22,30 @@ class PaintByExample(InpaintModel):
|
||||
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
|
||||
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
||||
model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
|
||||
|
||||
if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
|
||||
logger.info("Disable Paint By Example Model NSFW checker")
|
||||
model_kwargs.update(dict(
|
||||
safety_checker=None,
|
||||
requires_safety_checker=False
|
||||
))
|
||||
|
||||
self.model = DiffusionPipeline.from_pretrained(
|
||||
"Fantasy-Studio/Paint-by-Example",
|
||||
torch_dtype=torch_dtype,
|
||||
**model_kwargs
|
||||
)
|
||||
self.model = self.model.to(device)
|
||||
|
||||
self.model.enable_attention_slicing()
|
||||
if kwargs.get('enable_xformers', False):
|
||||
self.model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
# TODO: gpu_id
|
||||
if kwargs.get('cpu_offload', False) and torch.cuda.is_available():
|
||||
if kwargs.get('cpu_offload', False) and use_gpu:
|
||||
self.model.image_encoder = self.model.image_encoder.to(device)
|
||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||
else:
|
||||
self.model = self.model.to(device)
|
||||
|
||||
def forward(self, image, mask, config: Config):
|
||||
"""Input image and output image have same size
|
||||
|
@ -37,10 +37,12 @@ class SD(InpaintModel):
|
||||
fp16 = not kwargs.get('no_half', False)
|
||||
|
||||
model_kwargs = {"local_files_only": kwargs.get('local_files_only', kwargs['sd_run_local'])}
|
||||
if kwargs['sd_disable_nsfw']:
|
||||
if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
|
||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||
model_kwargs.update(dict(
|
||||
safety_checker=None,
|
||||
feature_extractor=None,
|
||||
requires_safety_checker=False
|
||||
))
|
||||
|
||||
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
|
||||
@ -52,19 +54,19 @@ class SD(InpaintModel):
|
||||
use_auth_token=kwargs["hf_access_token"],
|
||||
**model_kwargs
|
||||
)
|
||||
self.model = self.model.to(device)
|
||||
|
||||
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
||||
self.model.enable_attention_slicing()
|
||||
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
|
||||
if kwargs.get('sd_enable_xformers', False):
|
||||
if kwargs.get('enable_xformers', False):
|
||||
self.model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
if kwargs.get('cpu_offload', False) and torch.cuda.is_available():
|
||||
if kwargs.get('cpu_offload', False) and use_gpu:
|
||||
# TODO: gpu_id
|
||||
logger.info("Enable sequential cpu offload")
|
||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||
else:
|
||||
self.model = self.model.to(device)
|
||||
if kwargs['sd_cpu_textencoder']:
|
||||
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||
self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder, torch_dtype)
|
||||
|
@ -28,6 +28,11 @@ def parse_args():
|
||||
action="store_true",
|
||||
help="Disable Stable Diffusion NSFW checker",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-nsfw",
|
||||
action="store_true",
|
||||
help="Disable Stable Diffusion/Paint By Example NSFW checker",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sd-cpu-textencoder",
|
||||
action="store_true",
|
||||
@ -48,6 +53,11 @@ def parse_args():
|
||||
action="store_true",
|
||||
help="Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-xformers",
|
||||
action="store_true",
|
||||
help="sd/paint_by_example model. Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers"
|
||||
)
|
||||
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu", "mps"])
|
||||
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
|
||||
parser.add_argument(
|
||||
@ -78,7 +88,8 @@ def parse_args():
|
||||
if args.device == "cuda":
|
||||
import torch
|
||||
if torch.cuda.is_available() is False:
|
||||
parser.error("torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation")
|
||||
parser.error(
|
||||
"torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation")
|
||||
|
||||
if args.model_dir is not None:
|
||||
if os.path.isfile(args.model_dir):
|
||||
@ -90,7 +101,6 @@ def parse_args():
|
||||
|
||||
os.environ["XDG_CACHE_HOME"] = args.model_dir
|
||||
|
||||
|
||||
if args.input is not None:
|
||||
if not os.path.exists(args.input):
|
||||
parser.error(f"invalid --input: {args.input} not exists")
|
||||
|
@ -392,12 +392,12 @@ def main(args):
|
||||
device=device,
|
||||
no_half=args.no_half,
|
||||
hf_access_token=args.hf_access_token,
|
||||
sd_disable_nsfw=args.sd_disable_nsfw,
|
||||
disable_nsfw=args.sd_disable_nsfw or args.disable_nsfw,
|
||||
sd_cpu_textencoder=args.sd_cpu_textencoder,
|
||||
sd_run_local=args.sd_run_local,
|
||||
local_files_only=args.local_files_only,
|
||||
cpu_offload=args.cpu_offload,
|
||||
sd_enable_xformers=args.sd_enable_xformers,
|
||||
enable_xformers=args.sd_enable_xformers or args.enable_xformers,
|
||||
callback=diffuser_callback,
|
||||
)
|
||||
|
||||
|
@ -37,7 +37,7 @@ def assert_equal(
|
||||
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
def test_paint_by_example(strategy):
|
||||
model = ModelManager(name="paint_by_example", device=device)
|
||||
model = ModelManager(name="paint_by_example", device=device, disable_nsfw=True)
|
||||
cfg = get_config(strategy, paint_by_example_steps=30)
|
||||
assert_equal(
|
||||
model,
|
||||
@ -50,9 +50,22 @@ def test_paint_by_example(strategy):
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
def test_paint_by_example_disable_nsfw(strategy):
|
||||
model = ModelManager(name="paint_by_example", device=device, disable_nsfw=False)
|
||||
cfg = get_config(strategy, paint_by_example_steps=30)
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"paint_by_example_{strategy.capitalize()}_disable_nsfw.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
def test_paint_by_example_sd_scale(strategy):
|
||||
model = ModelManager(name="paint_by_example", device=device)
|
||||
model = ModelManager(name="paint_by_example", device=device, disable_nsfw=True)
|
||||
cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85)
|
||||
assert_equal(
|
||||
model,
|
||||
@ -67,7 +80,7 @@ def test_paint_by_example_sd_scale(strategy):
|
||||
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
def test_paint_by_example_cpu_offload(strategy):
|
||||
model = ModelManager(name="paint_by_example", device=device, cpu_offload=True)
|
||||
model = ModelManager(name="paint_by_example", device=device, cpu_offload=True, disable_nsfw=False)
|
||||
cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85)
|
||||
assert_equal(
|
||||
model,
|
||||
@ -75,14 +88,12 @@ def test_paint_by_example_cpu_offload(strategy):
|
||||
f"paint_by_example_{strategy.capitalize()}_cpu_offload.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fy=0.9,
|
||||
fx=1.3
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
def test_paint_by_example_cpu_offload_cpu_device(strategy):
|
||||
model = ModelManager(name="paint_by_example", device = torch.device('cpu'), cpu_offload=True)
|
||||
model = ModelManager(name="paint_by_example", device=torch.device('cpu'), cpu_offload=True, disable_nsfw=True)
|
||||
cfg = get_config(strategy, paint_by_example_steps=1, sd_scale=0.85)
|
||||
assert_equal(
|
||||
model,
|
||||
|
@ -14,7 +14,7 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
device = torch.device(device)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sd_device", ['cpu', 'cuda'])
|
||||
@pytest.mark.parametrize("sd_device", ['cuda'])
|
||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||
@pytest.mark.parametrize("cpu_textencoder", [True, False])
|
||||
@ -31,7 +31,7 @@ def test_runway_sd_1_5_ddim(sd_device, strategy, sampler, cpu_textencoder, disab
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
sd_run_local=True,
|
||||
sd_disable_nsfw=disable_nsfw,
|
||||
disable_nsfw=disable_nsfw,
|
||||
sd_cpu_textencoder=cpu_textencoder,
|
||||
callback=callback)
|
||||
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps)
|
||||
@ -66,7 +66,7 @@ def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_ns
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
sd_run_local=True,
|
||||
sd_disable_nsfw=disable_nsfw,
|
||||
disable_nsfw=disable_nsfw,
|
||||
sd_cpu_textencoder=cpu_textencoder,
|
||||
callback=callback)
|
||||
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps)
|
||||
@ -99,7 +99,7 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
sd_run_local=True,
|
||||
sd_disable_nsfw=False,
|
||||
disable_nsfw=False,
|
||||
sd_cpu_textencoder=False,
|
||||
callback=callback)
|
||||
cfg = get_config(
|
||||
@ -137,7 +137,7 @@ def test_runway_sd_1_5_sd_scale(sd_device, strategy, sampler, cpu_textencoder, d
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
sd_run_local=True,
|
||||
sd_disable_nsfw=disable_nsfw,
|
||||
disable_nsfw=disable_nsfw,
|
||||
sd_cpu_textencoder=cpu_textencoder)
|
||||
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85)
|
||||
cfg.sd_sampler = sampler
|
||||
@ -166,7 +166,7 @@ def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
sd_run_local=True,
|
||||
sd_disable_nsfw=False,
|
||||
disable_nsfw=True,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=True)
|
||||
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85)
|
||||
@ -191,7 +191,7 @@ def test_runway_sd_1_5_cpu_offload_cpu_device(sd_device, strategy, sampler):
|
||||
device=torch.device(sd_device),
|
||||
hf_access_token="",
|
||||
sd_run_local=True,
|
||||
sd_disable_nsfw=False,
|
||||
disable_nsfw=False,
|
||||
sd_cpu_textencoder=False,
|
||||
cpu_offload=True)
|
||||
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=1, sd_scale=0.85)
|
||||
|
Loading…
Reference in New Issue
Block a user