From c4968dd0a96b2e304bf743beaa1917cee68c3aac Mon Sep 17 00:00:00 2001 From: Qing Date: Tue, 28 Mar 2023 16:36:41 +0800 Subject: [PATCH] fix gfpgan cpu/mps device --- lama_cleaner/parse_args.py | 2 +- lama_cleaner/plugins/gfpgan_plugin.py | 4 ++++ lama_cleaner/tests/test_plugins.py | 14 +++++++++----- 3 files changed, 14 insertions(+), 6 deletions(-) diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py index df9fe98..c710b51 100644 --- a/lama_cleaner/parse_args.py +++ b/lama_cleaner/parse_args.py @@ -107,7 +107,7 @@ def parse_args(): help="Enable GFPGAN face restore", ) parser.add_argument( - "--gfpgan-device", default="cpu", type=str, choices=["cpu", "cuda"] + "--gfpgan-device", default="cpu", type=str, choices=["cpu", "cuda", "mps"] ) parser.add_argument( "--enable-gif", diff --git a/lama_cleaner/plugins/gfpgan_plugin.py b/lama_cleaner/plugins/gfpgan_plugin.py index cee1834..17032fb 100644 --- a/lama_cleaner/plugins/gfpgan_plugin.py +++ b/lama_cleaner/plugins/gfpgan_plugin.py @@ -17,6 +17,10 @@ class GFPGANPlugin(BasePlugin): model_path = download_model(url, model_md5) logger.info(f"GFPGAN model path: {model_path}") + import facexlib + if hasattr(facexlib.detection.retinaface, "device"): + facexlib.detection.retinaface.device = device + # Use GFPGAN for face enhancement self.face_enhancer = MyGFPGANer( model_path=model_path, diff --git a/lama_cleaner/tests/test_plugins.py b/lama_cleaner/tests/test_plugins.py index 6d68ec9..2f41ed8 100644 --- a/lama_cleaner/tests/test_plugins.py +++ b/lama_cleaner/tests/test_plugins.py @@ -24,23 +24,27 @@ def test_remove_bg(): _save(res, "test_remove_bg.png") -@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"]) def test_upscale(device): if device == "cuda" and not torch.cuda.is_available(): return + if device == "mps" and not torch.backends.mps.is_available(): + return model = RealESRGANUpscaler("realesr-general-x4v3", device) res = model.forward(bgr_img, 2) - _save(res, "test_upscale_x2.png") + _save(res, f"test_upscale_x2_{device}.png") res = model.forward(bgr_img, 4) - _save(res, "test_upscale_x4.png") + _save(res, f"test_upscale_x4_{device}.png") -@pytest.mark.parametrize("device", ["cuda", "cpu"]) +@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"]) def test_gfpgan(device): if device == "cuda" and not torch.cuda.is_available(): return + if device == "mps" and not torch.backends.mps.is_available(): + return model = GFPGANPlugin(device) res = model(rgb_img, None, None) - _save(res, "test_gfpgan.png") + _save(res, f"test_gfpgan_{device}.png")