From c2005786d7a9b6e62897e1a4a8aac9fe2a042e9f Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 14 Nov 2022 18:19:50 +0800 Subject: [PATCH] fix slow sd test --- lama_cleaner/model_manager.py | 4 +++- lama_cleaner/tests/test_model.py | 9 +++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py index 602fc10..f3b8eda 100644 --- a/lama_cleaner/model_manager.py +++ b/lama_cleaner/model_manager.py @@ -1,3 +1,5 @@ +import torch + from lama_cleaner.model.fcf import FcF from lama_cleaner.model.lama import LaMa from lama_cleaner.model.ldm import LDM @@ -11,7 +13,7 @@ models = {"lama": LaMa, "ldm": LDM, "zits": ZITS, "mat": MAT, "fcf": FcF, "sd1.5 class ModelManager: - def __init__(self, name: str, device, **kwargs): + def __init__(self, name: str, device: torch.device, **kwargs): self.name = name self.device = device self.kwargs = kwargs diff --git a/lama_cleaner/tests/test_model.py b/lama_cleaner/tests/test_model.py index fd7ccb1..2b24bfe 100644 --- a/lama_cleaner/tests/test_model.py +++ b/lama_cleaner/tests/test_model.py @@ -12,6 +12,7 @@ current_dir = Path(__file__).parent.absolute().resolve() save_dir = current_dir / 'result' save_dir.mkdir(exist_ok=True, parents=True) device = 'cuda' if torch.cuda.is_available() else 'cpu' +device = torch.device(device) def get_data(fx: float = 1, fy: float = 1.0, img_p=current_dir / "image.png", mask_p=current_dir / "mask.png"): @@ -172,9 +173,9 @@ def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_ns if sd_device == 'cuda' and not torch.cuda.is_available(): return - sd_steps = 1 + sd_steps = 50 model = ModelManager(name="sd1.5", - device=sd_device, + device=torch.device(sd_device), hf_access_token="", sd_run_local=True, sd_disable_nsfw=disable_nsfw, @@ -207,7 +208,7 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler): sd_steps = 50 model = ModelManager(name="sd1.5", - device=sd_device, + device=torch.device(sd_device), hf_access_token="", sd_run_local=True, sd_disable_nsfw=True, @@ -241,7 +242,7 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler): def test_cv2(strategy, cv2_flag, cv2_radius): model = ModelManager( name="cv2", - device=device, + device=torch.device(device), ) cfg = get_config(strategy, cv2_flag=cv2_flag, cv2_radius=cv2_radius) assert_equal(