optimize sd/paint_by_example model VRAM usage
This commit is contained in:
parent
384f16dcd0
commit
148e97e8da
@ -102,7 +102,7 @@ if __name__ == "__main__":
|
|||||||
name=args.name,
|
name=args.name,
|
||||||
device=device,
|
device=device,
|
||||||
sd_run_local=True,
|
sd_run_local=True,
|
||||||
sd_disable_nsfw=True,
|
disable_nsfw=True,
|
||||||
sd_cpu_textencoder=True,
|
sd_cpu_textencoder=True,
|
||||||
hf_access_token="123"
|
hf_access_token="123"
|
||||||
)
|
)
|
||||||
|
@ -22,16 +22,30 @@ class PaintByExample(InpaintModel):
|
|||||||
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
|
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
|
||||||
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
|
||||||
model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
|
model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)}
|
||||||
|
|
||||||
|
if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
|
||||||
|
logger.info("Disable Paint By Example Model NSFW checker")
|
||||||
|
model_kwargs.update(dict(
|
||||||
|
safety_checker=None,
|
||||||
|
requires_safety_checker=False
|
||||||
|
))
|
||||||
|
|
||||||
self.model = DiffusionPipeline.from_pretrained(
|
self.model = DiffusionPipeline.from_pretrained(
|
||||||
"Fantasy-Studio/Paint-by-Example",
|
"Fantasy-Studio/Paint-by-Example",
|
||||||
torch_dtype=torch_dtype,
|
torch_dtype=torch_dtype,
|
||||||
**model_kwargs
|
**model_kwargs
|
||||||
)
|
)
|
||||||
self.model = self.model.to(device)
|
|
||||||
self.model.enable_attention_slicing()
|
self.model.enable_attention_slicing()
|
||||||
|
if kwargs.get('enable_xformers', False):
|
||||||
|
self.model.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
# TODO: gpu_id
|
# TODO: gpu_id
|
||||||
if kwargs.get('cpu_offload', False) and torch.cuda.is_available():
|
if kwargs.get('cpu_offload', False) and use_gpu:
|
||||||
|
self.model.image_encoder = self.model.image_encoder.to(device)
|
||||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||||
|
else:
|
||||||
|
self.model = self.model.to(device)
|
||||||
|
|
||||||
def forward(self, image, mask, config: Config):
|
def forward(self, image, mask, config: Config):
|
||||||
"""Input image and output image have same size
|
"""Input image and output image have same size
|
||||||
|
@ -37,10 +37,12 @@ class SD(InpaintModel):
|
|||||||
fp16 = not kwargs.get('no_half', False)
|
fp16 = not kwargs.get('no_half', False)
|
||||||
|
|
||||||
model_kwargs = {"local_files_only": kwargs.get('local_files_only', kwargs['sd_run_local'])}
|
model_kwargs = {"local_files_only": kwargs.get('local_files_only', kwargs['sd_run_local'])}
|
||||||
if kwargs['sd_disable_nsfw']:
|
if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False):
|
||||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||||
model_kwargs.update(dict(
|
model_kwargs.update(dict(
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
|
feature_extractor=None,
|
||||||
|
requires_safety_checker=False
|
||||||
))
|
))
|
||||||
|
|
||||||
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
|
use_gpu = device == torch.device('cuda') and torch.cuda.is_available()
|
||||||
@ -52,19 +54,19 @@ class SD(InpaintModel):
|
|||||||
use_auth_token=kwargs["hf_access_token"],
|
use_auth_token=kwargs["hf_access_token"],
|
||||||
**model_kwargs
|
**model_kwargs
|
||||||
)
|
)
|
||||||
self.model = self.model.to(device)
|
|
||||||
|
|
||||||
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
|
||||||
self.model.enable_attention_slicing()
|
self.model.enable_attention_slicing()
|
||||||
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
|
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
|
||||||
if kwargs.get('sd_enable_xformers', False):
|
if kwargs.get('enable_xformers', False):
|
||||||
self.model.enable_xformers_memory_efficient_attention()
|
self.model.enable_xformers_memory_efficient_attention()
|
||||||
|
|
||||||
if kwargs.get('cpu_offload', False) and torch.cuda.is_available():
|
if kwargs.get('cpu_offload', False) and use_gpu:
|
||||||
# TODO: gpu_id
|
# TODO: gpu_id
|
||||||
logger.info("Enable sequential cpu offload")
|
logger.info("Enable sequential cpu offload")
|
||||||
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
self.model.enable_sequential_cpu_offload(gpu_id=0)
|
||||||
else:
|
else:
|
||||||
|
self.model = self.model.to(device)
|
||||||
if kwargs['sd_cpu_textencoder']:
|
if kwargs['sd_cpu_textencoder']:
|
||||||
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
logger.info("Run Stable Diffusion TextEncoder on CPU")
|
||||||
self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder, torch_dtype)
|
self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder, torch_dtype)
|
||||||
|
@ -28,6 +28,11 @@ def parse_args():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Disable Stable Diffusion NSFW checker",
|
help="Disable Stable Diffusion NSFW checker",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--disable-nsfw",
|
||||||
|
action="store_true",
|
||||||
|
help="Disable Stable Diffusion/Paint By Example NSFW checker",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--sd-cpu-textencoder",
|
"--sd-cpu-textencoder",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@ -48,6 +53,11 @@ def parse_args():
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers"
|
help="Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers"
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--enable-xformers",
|
||||||
|
action="store_true",
|
||||||
|
help="sd/paint_by_example model. Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers"
|
||||||
|
)
|
||||||
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu", "mps"])
|
parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu", "mps"])
|
||||||
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
|
parser.add_argument("--gui", action="store_true", help="Launch as desktop app")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -78,7 +88,8 @@ def parse_args():
|
|||||||
if args.device == "cuda":
|
if args.device == "cuda":
|
||||||
import torch
|
import torch
|
||||||
if torch.cuda.is_available() is False:
|
if torch.cuda.is_available() is False:
|
||||||
parser.error("torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation")
|
parser.error(
|
||||||
|
"torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation")
|
||||||
|
|
||||||
if args.model_dir is not None:
|
if args.model_dir is not None:
|
||||||
if os.path.isfile(args.model_dir):
|
if os.path.isfile(args.model_dir):
|
||||||
@ -90,7 +101,6 @@ def parse_args():
|
|||||||
|
|
||||||
os.environ["XDG_CACHE_HOME"] = args.model_dir
|
os.environ["XDG_CACHE_HOME"] = args.model_dir
|
||||||
|
|
||||||
|
|
||||||
if args.input is not None:
|
if args.input is not None:
|
||||||
if not os.path.exists(args.input):
|
if not os.path.exists(args.input):
|
||||||
parser.error(f"invalid --input: {args.input} not exists")
|
parser.error(f"invalid --input: {args.input} not exists")
|
||||||
|
@ -392,12 +392,12 @@ def main(args):
|
|||||||
device=device,
|
device=device,
|
||||||
no_half=args.no_half,
|
no_half=args.no_half,
|
||||||
hf_access_token=args.hf_access_token,
|
hf_access_token=args.hf_access_token,
|
||||||
sd_disable_nsfw=args.sd_disable_nsfw,
|
disable_nsfw=args.sd_disable_nsfw or args.disable_nsfw,
|
||||||
sd_cpu_textencoder=args.sd_cpu_textencoder,
|
sd_cpu_textencoder=args.sd_cpu_textencoder,
|
||||||
sd_run_local=args.sd_run_local,
|
sd_run_local=args.sd_run_local,
|
||||||
local_files_only=args.local_files_only,
|
local_files_only=args.local_files_only,
|
||||||
cpu_offload=args.cpu_offload,
|
cpu_offload=args.cpu_offload,
|
||||||
sd_enable_xformers=args.sd_enable_xformers,
|
enable_xformers=args.sd_enable_xformers or args.enable_xformers,
|
||||||
callback=diffuser_callback,
|
callback=diffuser_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -37,7 +37,7 @@ def assert_equal(
|
|||||||
|
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
def test_paint_by_example(strategy):
|
def test_paint_by_example(strategy):
|
||||||
model = ModelManager(name="paint_by_example", device=device)
|
model = ModelManager(name="paint_by_example", device=device, disable_nsfw=True)
|
||||||
cfg = get_config(strategy, paint_by_example_steps=30)
|
cfg = get_config(strategy, paint_by_example_steps=30)
|
||||||
assert_equal(
|
assert_equal(
|
||||||
model,
|
model,
|
||||||
@ -50,9 +50,22 @@ def test_paint_by_example(strategy):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
|
def test_paint_by_example_disable_nsfw(strategy):
|
||||||
|
model = ModelManager(name="paint_by_example", device=device, disable_nsfw=False)
|
||||||
|
cfg = get_config(strategy, paint_by_example_steps=30)
|
||||||
|
assert_equal(
|
||||||
|
model,
|
||||||
|
cfg,
|
||||||
|
f"paint_by_example_{strategy.capitalize()}_disable_nsfw.png",
|
||||||
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
def test_paint_by_example_sd_scale(strategy):
|
def test_paint_by_example_sd_scale(strategy):
|
||||||
model = ModelManager(name="paint_by_example", device=device)
|
model = ModelManager(name="paint_by_example", device=device, disable_nsfw=True)
|
||||||
cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85)
|
cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85)
|
||||||
assert_equal(
|
assert_equal(
|
||||||
model,
|
model,
|
||||||
@ -67,7 +80,7 @@ def test_paint_by_example_sd_scale(strategy):
|
|||||||
|
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
def test_paint_by_example_cpu_offload(strategy):
|
def test_paint_by_example_cpu_offload(strategy):
|
||||||
model = ModelManager(name="paint_by_example", device=device, cpu_offload=True)
|
model = ModelManager(name="paint_by_example", device=device, cpu_offload=True, disable_nsfw=False)
|
||||||
cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85)
|
cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85)
|
||||||
assert_equal(
|
assert_equal(
|
||||||
model,
|
model,
|
||||||
@ -75,14 +88,12 @@ def test_paint_by_example_cpu_offload(strategy):
|
|||||||
f"paint_by_example_{strategy.capitalize()}_cpu_offload.png",
|
f"paint_by_example_{strategy.capitalize()}_cpu_offload.png",
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
fy=0.9,
|
|
||||||
fx=1.3
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
def test_paint_by_example_cpu_offload_cpu_device(strategy):
|
def test_paint_by_example_cpu_offload_cpu_device(strategy):
|
||||||
model = ModelManager(name="paint_by_example", device = torch.device('cpu'), cpu_offload=True)
|
model = ModelManager(name="paint_by_example", device=torch.device('cpu'), cpu_offload=True, disable_nsfw=True)
|
||||||
cfg = get_config(strategy, paint_by_example_steps=1, sd_scale=0.85)
|
cfg = get_config(strategy, paint_by_example_steps=1, sd_scale=0.85)
|
||||||
assert_equal(
|
assert_equal(
|
||||||
model,
|
model,
|
||||||
|
@ -14,7 +14,7 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|||||||
device = torch.device(device)
|
device = torch.device(device)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("sd_device", ['cpu', 'cuda'])
|
@pytest.mark.parametrize("sd_device", ['cuda'])
|
||||||
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
|
||||||
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
@pytest.mark.parametrize("sampler", [SDSampler.ddim])
|
||||||
@pytest.mark.parametrize("cpu_textencoder", [True, False])
|
@pytest.mark.parametrize("cpu_textencoder", [True, False])
|
||||||
@ -31,7 +31,7 @@ def test_runway_sd_1_5_ddim(sd_device, strategy, sampler, cpu_textencoder, disab
|
|||||||
device=torch.device(sd_device),
|
device=torch.device(sd_device),
|
||||||
hf_access_token="",
|
hf_access_token="",
|
||||||
sd_run_local=True,
|
sd_run_local=True,
|
||||||
sd_disable_nsfw=disable_nsfw,
|
disable_nsfw=disable_nsfw,
|
||||||
sd_cpu_textencoder=cpu_textencoder,
|
sd_cpu_textencoder=cpu_textencoder,
|
||||||
callback=callback)
|
callback=callback)
|
||||||
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps)
|
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps)
|
||||||
@ -66,7 +66,7 @@ def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_ns
|
|||||||
device=torch.device(sd_device),
|
device=torch.device(sd_device),
|
||||||
hf_access_token="",
|
hf_access_token="",
|
||||||
sd_run_local=True,
|
sd_run_local=True,
|
||||||
sd_disable_nsfw=disable_nsfw,
|
disable_nsfw=disable_nsfw,
|
||||||
sd_cpu_textencoder=cpu_textencoder,
|
sd_cpu_textencoder=cpu_textencoder,
|
||||||
callback=callback)
|
callback=callback)
|
||||||
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps)
|
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps)
|
||||||
@ -99,7 +99,7 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
|
|||||||
device=torch.device(sd_device),
|
device=torch.device(sd_device),
|
||||||
hf_access_token="",
|
hf_access_token="",
|
||||||
sd_run_local=True,
|
sd_run_local=True,
|
||||||
sd_disable_nsfw=False,
|
disable_nsfw=False,
|
||||||
sd_cpu_textencoder=False,
|
sd_cpu_textencoder=False,
|
||||||
callback=callback)
|
callback=callback)
|
||||||
cfg = get_config(
|
cfg = get_config(
|
||||||
@ -137,7 +137,7 @@ def test_runway_sd_1_5_sd_scale(sd_device, strategy, sampler, cpu_textencoder, d
|
|||||||
device=torch.device(sd_device),
|
device=torch.device(sd_device),
|
||||||
hf_access_token="",
|
hf_access_token="",
|
||||||
sd_run_local=True,
|
sd_run_local=True,
|
||||||
sd_disable_nsfw=disable_nsfw,
|
disable_nsfw=disable_nsfw,
|
||||||
sd_cpu_textencoder=cpu_textencoder)
|
sd_cpu_textencoder=cpu_textencoder)
|
||||||
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85)
|
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85)
|
||||||
cfg.sd_sampler = sampler
|
cfg.sd_sampler = sampler
|
||||||
@ -166,7 +166,7 @@ def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler):
|
|||||||
device=torch.device(sd_device),
|
device=torch.device(sd_device),
|
||||||
hf_access_token="",
|
hf_access_token="",
|
||||||
sd_run_local=True,
|
sd_run_local=True,
|
||||||
sd_disable_nsfw=False,
|
disable_nsfw=True,
|
||||||
sd_cpu_textencoder=False,
|
sd_cpu_textencoder=False,
|
||||||
cpu_offload=True)
|
cpu_offload=True)
|
||||||
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85)
|
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85)
|
||||||
@ -191,7 +191,7 @@ def test_runway_sd_1_5_cpu_offload_cpu_device(sd_device, strategy, sampler):
|
|||||||
device=torch.device(sd_device),
|
device=torch.device(sd_device),
|
||||||
hf_access_token="",
|
hf_access_token="",
|
||||||
sd_run_local=True,
|
sd_run_local=True,
|
||||||
sd_disable_nsfw=False,
|
disable_nsfw=False,
|
||||||
sd_cpu_textencoder=False,
|
sd_cpu_textencoder=False,
|
||||||
cpu_offload=True)
|
cpu_offload=True)
|
||||||
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=1, sd_scale=0.85)
|
cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=1, sd_scale=0.85)
|
||||||
|
Loading…
Reference in New Issue
Block a user