From 23943b0ebdc8e30a541087f1756ca7d67f60ced0 Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 5 Sep 2022 13:07:25 +0800 Subject: [PATCH] update test --- lama_cleaner/tests/test_model.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/lama_cleaner/tests/test_model.py b/lama_cleaner/tests/test_model.py index 3050891..8ade0c2 100644 --- a/lama_cleaner/tests/test_model.py +++ b/lama_cleaner/tests/test_model.py @@ -2,11 +2,13 @@ from pathlib import Path import cv2 import pytest +import torch from lama_cleaner.model_manager import ModelManager from lama_cleaner.schema import Config, HDStrategy, LDMSampler current_dir = Path(__file__).parent.absolute().resolve() +device = 'cuda' if torch.cuda.is_available() else 'cpu' def get_data(fx=1, fy=1.0): @@ -55,7 +57,7 @@ def assert_equal(model, config, gt_name, fx=1, fy=1): "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP] ) def test_lama(strategy): - model = ModelManager(name="lama", device="cpu") + model = ModelManager(name="lama", device=device) assert_equal( model, get_config(strategy), @@ -76,7 +78,7 @@ def test_lama(strategy): ) @pytest.mark.parametrize("ldm_sampler", [LDMSampler.ddim, LDMSampler.plms]) def test_ldm(strategy, ldm_sampler): - model = ModelManager(name="ldm", device="cpu") + model = ModelManager(name="ldm", device=device) cfg = get_config(strategy, ldm_sampler=ldm_sampler) assert_equal( model, cfg, f"ldm_{strategy[0].upper() + strategy[1:]}_{ldm_sampler}_result.png" @@ -96,7 +98,7 @@ def test_ldm(strategy, ldm_sampler): ) @pytest.mark.parametrize("zits_wireframe", [False, True]) def test_zits(strategy, zits_wireframe): - model = ModelManager(name="zits", device="cpu") + model = ModelManager(name="zits", device=device) cfg = get_config(strategy, zits_wireframe=zits_wireframe) # os.environ['ZITS_DEBUG_LINE_PATH'] = str(current_dir / 'zits_debug_line.jpg') # os.environ['ZITS_DEBUG_EDGE_PATH'] = str(current_dir / 'zits_debug_edge.jpg') @@ -119,7 +121,7 @@ def test_zits(strategy, zits_wireframe): "strategy", [HDStrategy.ORIGINAL] ) def test_mat(strategy): - model = ModelManager(name="mat", device="cpu") + model = ModelManager(name="mat", device=device) cfg = get_config(strategy) assert_equal( @@ -133,7 +135,7 @@ def test_mat(strategy): "strategy", [HDStrategy.ORIGINAL] ) def test_fcf(strategy): - model = ModelManager(name="fcf", device="cpu") + model = ModelManager(name="fcf", device=device) cfg = get_config(strategy) assert_equal(