diff --git a/lama_cleaner/benchmark.py b/lama_cleaner/benchmark.py index a548a70..a0a170e 100644 --- a/lama_cleaner/benchmark.py +++ b/lama_cleaner/benchmark.py @@ -102,7 +102,7 @@ if __name__ == "__main__": name=args.name, device=device, sd_run_local=True, - sd_disable_nsfw=True, + disable_nsfw=True, sd_cpu_textencoder=True, hf_access_token="123" ) diff --git a/lama_cleaner/model/paint_by_example.py b/lama_cleaner/model/paint_by_example.py index bfd7b1e..b946a08 100644 --- a/lama_cleaner/model/paint_by_example.py +++ b/lama_cleaner/model/paint_by_example.py @@ -22,16 +22,30 @@ class PaintByExample(InpaintModel): use_gpu = device == torch.device('cuda') and torch.cuda.is_available() torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 model_kwargs = {"local_files_only": kwargs.get('local_files_only', False)} + + if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False): + logger.info("Disable Paint By Example Model NSFW checker") + model_kwargs.update(dict( + safety_checker=None, + requires_safety_checker=False + )) + self.model = DiffusionPipeline.from_pretrained( "Fantasy-Studio/Paint-by-Example", torch_dtype=torch_dtype, **model_kwargs ) - self.model = self.model.to(device) + self.model.enable_attention_slicing() + if kwargs.get('enable_xformers', False): + self.model.enable_xformers_memory_efficient_attention() + # TODO: gpu_id - if kwargs.get('cpu_offload', False) and torch.cuda.is_available(): + if kwargs.get('cpu_offload', False) and use_gpu: + self.model.image_encoder = self.model.image_encoder.to(device) self.model.enable_sequential_cpu_offload(gpu_id=0) + else: + self.model = self.model.to(device) def forward(self, image, mask, config: Config): """Input image and output image have same size diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py index 88029b1..08585bf 100644 --- a/lama_cleaner/model/sd.py +++ b/lama_cleaner/model/sd.py @@ -37,10 +37,12 @@ class SD(InpaintModel): fp16 = not kwargs.get('no_half', False) model_kwargs = {"local_files_only": kwargs.get('local_files_only', kwargs['sd_run_local'])} - if kwargs['sd_disable_nsfw']: + if kwargs['disable_nsfw'] or kwargs.get('cpu_offload', False): logger.info("Disable Stable Diffusion Model NSFW checker") model_kwargs.update(dict( safety_checker=None, + feature_extractor=None, + requires_safety_checker=False )) use_gpu = device == torch.device('cuda') and torch.cuda.is_available() @@ -52,19 +54,19 @@ class SD(InpaintModel): use_auth_token=kwargs["hf_access_token"], **model_kwargs ) - self.model = self.model.to(device) # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing self.model.enable_attention_slicing() # https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention - if kwargs.get('sd_enable_xformers', False): + if kwargs.get('enable_xformers', False): self.model.enable_xformers_memory_efficient_attention() - if kwargs.get('cpu_offload', False) and torch.cuda.is_available(): + if kwargs.get('cpu_offload', False) and use_gpu: # TODO: gpu_id logger.info("Enable sequential cpu offload") self.model.enable_sequential_cpu_offload(gpu_id=0) else: + self.model = self.model.to(device) if kwargs['sd_cpu_textencoder']: logger.info("Run Stable Diffusion TextEncoder on CPU") self.model.text_encoder = CPUTextEncoderWrapper(self.model.text_encoder, torch_dtype) diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py index 1c607ed..87e3137 100644 --- a/lama_cleaner/parse_args.py +++ b/lama_cleaner/parse_args.py @@ -28,6 +28,11 @@ def parse_args(): action="store_true", help="Disable Stable Diffusion NSFW checker", ) + parser.add_argument( + "--disable-nsfw", + action="store_true", + help="Disable Stable Diffusion/Paint By Example NSFW checker", + ) parser.add_argument( "--sd-cpu-textencoder", action="store_true", @@ -48,6 +53,11 @@ def parse_args(): action="store_true", help="Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers" ) + parser.add_argument( + "--enable-xformers", + action="store_true", + help="sd/paint_by_example model. Enable xFormers optimizations. Requires that xformers package has been installed. See: https://github.com/facebookresearch/xformers" + ) parser.add_argument("--device", default="cuda", type=str, choices=["cuda", "cpu", "mps"]) parser.add_argument("--gui", action="store_true", help="Launch as desktop app") parser.add_argument( @@ -78,7 +88,8 @@ def parse_args(): if args.device == "cuda": import torch if torch.cuda.is_available() is False: - parser.error("torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation") + parser.error( + "torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation") if args.model_dir is not None: if os.path.isfile(args.model_dir): @@ -90,7 +101,6 @@ def parse_args(): os.environ["XDG_CACHE_HOME"] = args.model_dir - if args.input is not None: if not os.path.exists(args.input): parser.error(f"invalid --input: {args.input} not exists") diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index 4eae2d3..0c34430 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -392,12 +392,12 @@ def main(args): device=device, no_half=args.no_half, hf_access_token=args.hf_access_token, - sd_disable_nsfw=args.sd_disable_nsfw, + disable_nsfw=args.sd_disable_nsfw or args.disable_nsfw, sd_cpu_textencoder=args.sd_cpu_textencoder, sd_run_local=args.sd_run_local, local_files_only=args.local_files_only, cpu_offload=args.cpu_offload, - sd_enable_xformers=args.sd_enable_xformers, + enable_xformers=args.sd_enable_xformers or args.enable_xformers, callback=diffuser_callback, ) diff --git a/lama_cleaner/tests/test_paint_by_example.py b/lama_cleaner/tests/test_paint_by_example.py index 316c334..c495690 100644 --- a/lama_cleaner/tests/test_paint_by_example.py +++ b/lama_cleaner/tests/test_paint_by_example.py @@ -37,7 +37,7 @@ def assert_equal( @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) def test_paint_by_example(strategy): - model = ModelManager(name="paint_by_example", device=device) + model = ModelManager(name="paint_by_example", device=device, disable_nsfw=True) cfg = get_config(strategy, paint_by_example_steps=30) assert_equal( model, @@ -50,9 +50,22 @@ def test_paint_by_example(strategy): ) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +def test_paint_by_example_disable_nsfw(strategy): + model = ModelManager(name="paint_by_example", device=device, disable_nsfw=False) + cfg = get_config(strategy, paint_by_example_steps=30) + assert_equal( + model, + cfg, + f"paint_by_example_{strategy.capitalize()}_disable_nsfw.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) def test_paint_by_example_sd_scale(strategy): - model = ModelManager(name="paint_by_example", device=device) + model = ModelManager(name="paint_by_example", device=device, disable_nsfw=True) cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85) assert_equal( model, @@ -67,7 +80,7 @@ def test_paint_by_example_sd_scale(strategy): @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) def test_paint_by_example_cpu_offload(strategy): - model = ModelManager(name="paint_by_example", device=device, cpu_offload=True) + model = ModelManager(name="paint_by_example", device=device, cpu_offload=True, disable_nsfw=False) cfg = get_config(strategy, paint_by_example_steps=30, sd_scale=0.85) assert_equal( model, @@ -75,14 +88,12 @@ def test_paint_by_example_cpu_offload(strategy): f"paint_by_example_{strategy.capitalize()}_cpu_offload.png", img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", - fy=0.9, - fx=1.3 ) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) def test_paint_by_example_cpu_offload_cpu_device(strategy): - model = ModelManager(name="paint_by_example", device = torch.device('cpu'), cpu_offload=True) + model = ModelManager(name="paint_by_example", device=torch.device('cpu'), cpu_offload=True, disable_nsfw=True) cfg = get_config(strategy, paint_by_example_steps=1, sd_scale=0.85) assert_equal( model, diff --git a/lama_cleaner/tests/test_sd_model.py b/lama_cleaner/tests/test_sd_model.py index 4ee59f4..d643b28 100644 --- a/lama_cleaner/tests/test_sd_model.py +++ b/lama_cleaner/tests/test_sd_model.py @@ -14,7 +14,7 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu' device = torch.device(device) -@pytest.mark.parametrize("sd_device", ['cpu', 'cuda']) +@pytest.mark.parametrize("sd_device", ['cuda']) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("sampler", [SDSampler.ddim]) @pytest.mark.parametrize("cpu_textencoder", [True, False]) @@ -31,7 +31,7 @@ def test_runway_sd_1_5_ddim(sd_device, strategy, sampler, cpu_textencoder, disab device=torch.device(sd_device), hf_access_token="", sd_run_local=True, - sd_disable_nsfw=disable_nsfw, + disable_nsfw=disable_nsfw, sd_cpu_textencoder=cpu_textencoder, callback=callback) cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps) @@ -66,7 +66,7 @@ def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_ns device=torch.device(sd_device), hf_access_token="", sd_run_local=True, - sd_disable_nsfw=disable_nsfw, + disable_nsfw=disable_nsfw, sd_cpu_textencoder=cpu_textencoder, callback=callback) cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps) @@ -99,7 +99,7 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler): device=torch.device(sd_device), hf_access_token="", sd_run_local=True, - sd_disable_nsfw=False, + disable_nsfw=False, sd_cpu_textencoder=False, callback=callback) cfg = get_config( @@ -137,7 +137,7 @@ def test_runway_sd_1_5_sd_scale(sd_device, strategy, sampler, cpu_textencoder, d device=torch.device(sd_device), hf_access_token="", sd_run_local=True, - sd_disable_nsfw=disable_nsfw, + disable_nsfw=disable_nsfw, sd_cpu_textencoder=cpu_textencoder) cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85) cfg.sd_sampler = sampler @@ -166,7 +166,7 @@ def test_runway_sd_1_5_cpu_offload(sd_device, strategy, sampler): device=torch.device(sd_device), hf_access_token="", sd_run_local=True, - sd_disable_nsfw=False, + disable_nsfw=True, sd_cpu_textencoder=False, cpu_offload=True) cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=sd_steps, sd_scale=0.85) @@ -191,7 +191,7 @@ def test_runway_sd_1_5_cpu_offload_cpu_device(sd_device, strategy, sampler): device=torch.device(sd_device), hf_access_token="", sd_run_local=True, - sd_disable_nsfw=False, + disable_nsfw=False, sd_cpu_textencoder=False, cpu_offload=True) cfg = get_config(strategy, prompt='a fox sitting on a bench', sd_steps=1, sd_scale=0.85)