diff --git a/lama_cleaner/tests/test_model.py b/lama_cleaner/tests/test_model.py index 40d484f..bfccac8 100644 --- a/lama_cleaner/tests/test_model.py +++ b/lama_cleaner/tests/test_model.py @@ -11,10 +11,13 @@ from lama_cleaner.schema import Config, HDStrategy, LDMSampler current_dir = Path(__file__).parent.absolute().resolve() -def get_data(): - img = cv2.imread(str(current_dir / 'image.png')) +def get_data(fx=1): + img = cv2.imread(str(current_dir / "image.png")) img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB) - mask = cv2.imread(str(current_dir / 'mask.png'), cv2.IMREAD_GRAYSCALE) + mask = cv2.imread(str(current_dir / "mask.png"), cv2.IMREAD_GRAYSCALE) + if fx != 1: + img = cv2.resize(img, None, fx=fx, fy=1) + mask = cv2.resize(mask, None, fx=fx, fy=1) return img, mask @@ -31,11 +34,14 @@ def get_config(strategy, **kwargs): return Config(**data) -def assert_equal(model, config, gt_name): - img, mask = get_data() +def assert_equal(model, config, gt_name, fx=1): + img, mask = get_data(fx=fx) res = model(img, mask, config) - cv2.imwrite(str(current_dir / gt_name), res, - [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0]) + cv2.imwrite( + str(current_dir / gt_name), + res, + [int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0], + ) """ Note that JPEG is lossy compression, so even if it is the highest quality 100, @@ -46,25 +52,65 @@ def assert_equal(model, config, gt_name): # assert np.array_equal(res, gt) -@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]) +@pytest.mark.parametrize( + "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP] +) def test_lama(strategy): - model = ModelManager(name='lama', device='cpu') - assert_equal(model, get_config(strategy), f'lama_{strategy[0].upper() + strategy[1:]}_result.png') + model = ModelManager(name="lama", device="cpu") + assert_equal( + model, + get_config(strategy), + f"lama_{strategy[0].upper() + strategy[1:]}_result.png", + ) + + fx = 1.3 + assert_equal( + model, + get_config(strategy), + f"lama_{strategy[0].upper() + strategy[1:]}_fx_{fx}_result.png", + fx=1.3, + ) -@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]) -@pytest.mark.parametrize('ldm_sampler', [LDMSampler.ddim, LDMSampler.plms]) +@pytest.mark.parametrize( + "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP] +) +@pytest.mark.parametrize("ldm_sampler", [LDMSampler.ddim, LDMSampler.plms]) def test_ldm(strategy, ldm_sampler): - model = ModelManager(name='ldm', device='cpu') + model = ModelManager(name="ldm", device="cpu") cfg = get_config(strategy, ldm_sampler=ldm_sampler) - assert_equal(model, cfg, f'ldm_{strategy[0].upper() + strategy[1:]}_{ldm_sampler}_result.png') + assert_equal( + model, cfg, f"ldm_{strategy[0].upper() + strategy[1:]}_{ldm_sampler}_result.png" + ) + + fx = 1.3 + assert_equal( + model, + cfg, + f"ldm_{strategy[0].upper() + strategy[1:]}_{ldm_sampler}_fx_{fx}_result.png", + fx=fx, + ) -@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]) -@pytest.mark.parametrize('zits_wireframe', [False, True]) +@pytest.mark.parametrize( + "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP] +) +@pytest.mark.parametrize("zits_wireframe", [False, True]) def test_zits(strategy, zits_wireframe): - model = ModelManager(name='zits', device='cpu') + model = ModelManager(name="zits", device="cpu") cfg = get_config(strategy, zits_wireframe=zits_wireframe) # os.environ['ZITS_DEBUG_LINE_PATH'] = str(current_dir / 'zits_debug_line.jpg') # os.environ['ZITS_DEBUG_EDGE_PATH'] = str(current_dir / 'zits_debug_edge.jpg') - assert_equal(model, cfg, f'zits_{strategy[0].upper() + strategy[1:]}_wireframe_{zits_wireframe}_result.png') + assert_equal( + model, + cfg, + f"zits_{strategy[0].upper() + strategy[1:]}_wireframe_{zits_wireframe}_result.png", + ) + + fx = 1.3 + assert_equal( + model, + cfg, + f"zits_{strategy[0].upper() + strategy[1:]}_wireframe_{zits_wireframe}_fx_{fx}_result.png", + fx=fx, + )