test add non square test
This commit is contained in:
parent
6e164c4915
commit
8c1162a9e3
@ -11,10 +11,13 @@ from lama_cleaner.schema import Config, HDStrategy, LDMSampler
|
||||
current_dir = Path(__file__).parent.absolute().resolve()
|
||||
|
||||
|
||||
def get_data():
|
||||
img = cv2.imread(str(current_dir / 'image.png'))
|
||||
def get_data(fx=1):
|
||||
img = cv2.imread(str(current_dir / "image.png"))
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
|
||||
mask = cv2.imread(str(current_dir / 'mask.png'), cv2.IMREAD_GRAYSCALE)
|
||||
mask = cv2.imread(str(current_dir / "mask.png"), cv2.IMREAD_GRAYSCALE)
|
||||
if fx != 1:
|
||||
img = cv2.resize(img, None, fx=fx, fy=1)
|
||||
mask = cv2.resize(mask, None, fx=fx, fy=1)
|
||||
return img, mask
|
||||
|
||||
|
||||
@ -31,11 +34,14 @@ def get_config(strategy, **kwargs):
|
||||
return Config(**data)
|
||||
|
||||
|
||||
def assert_equal(model, config, gt_name):
|
||||
img, mask = get_data()
|
||||
def assert_equal(model, config, gt_name, fx=1):
|
||||
img, mask = get_data(fx=fx)
|
||||
res = model(img, mask, config)
|
||||
cv2.imwrite(str(current_dir / gt_name), res,
|
||||
[int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0])
|
||||
cv2.imwrite(
|
||||
str(current_dir / gt_name),
|
||||
res,
|
||||
[int(cv2.IMWRITE_JPEG_QUALITY), 100, int(cv2.IMWRITE_PNG_COMPRESSION), 0],
|
||||
)
|
||||
|
||||
"""
|
||||
Note that JPEG is lossy compression, so even if it is the highest quality 100,
|
||||
@ -46,25 +52,65 @@ def assert_equal(model, config, gt_name):
|
||||
# assert np.array_equal(res, gt)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP])
|
||||
@pytest.mark.parametrize(
|
||||
"strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
|
||||
)
|
||||
def test_lama(strategy):
|
||||
model = ModelManager(name='lama', device='cpu')
|
||||
assert_equal(model, get_config(strategy), f'lama_{strategy[0].upper() + strategy[1:]}_result.png')
|
||||
model = ModelManager(name="lama", device="cpu")
|
||||
assert_equal(
|
||||
model,
|
||||
get_config(strategy),
|
||||
f"lama_{strategy[0].upper() + strategy[1:]}_result.png",
|
||||
)
|
||||
|
||||
fx = 1.3
|
||||
assert_equal(
|
||||
model,
|
||||
get_config(strategy),
|
||||
f"lama_{strategy[0].upper() + strategy[1:]}_fx_{fx}_result.png",
|
||||
fx=1.3,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP])
|
||||
@pytest.mark.parametrize('ldm_sampler', [LDMSampler.ddim, LDMSampler.plms])
|
||||
@pytest.mark.parametrize(
|
||||
"strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
|
||||
)
|
||||
@pytest.mark.parametrize("ldm_sampler", [LDMSampler.ddim, LDMSampler.plms])
|
||||
def test_ldm(strategy, ldm_sampler):
|
||||
model = ModelManager(name='ldm', device='cpu')
|
||||
model = ModelManager(name="ldm", device="cpu")
|
||||
cfg = get_config(strategy, ldm_sampler=ldm_sampler)
|
||||
assert_equal(model, cfg, f'ldm_{strategy[0].upper() + strategy[1:]}_{ldm_sampler}_result.png')
|
||||
assert_equal(
|
||||
model, cfg, f"ldm_{strategy[0].upper() + strategy[1:]}_{ldm_sampler}_result.png"
|
||||
)
|
||||
|
||||
fx = 1.3
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"ldm_{strategy[0].upper() + strategy[1:]}_{ldm_sampler}_fx_{fx}_result.png",
|
||||
fx=fx,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize('strategy', [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP])
|
||||
@pytest.mark.parametrize('zits_wireframe', [False, True])
|
||||
@pytest.mark.parametrize(
|
||||
"strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
|
||||
)
|
||||
@pytest.mark.parametrize("zits_wireframe", [False, True])
|
||||
def test_zits(strategy, zits_wireframe):
|
||||
model = ModelManager(name='zits', device='cpu')
|
||||
model = ModelManager(name="zits", device="cpu")
|
||||
cfg = get_config(strategy, zits_wireframe=zits_wireframe)
|
||||
# os.environ['ZITS_DEBUG_LINE_PATH'] = str(current_dir / 'zits_debug_line.jpg')
|
||||
# os.environ['ZITS_DEBUG_EDGE_PATH'] = str(current_dir / 'zits_debug_edge.jpg')
|
||||
assert_equal(model, cfg, f'zits_{strategy[0].upper() + strategy[1:]}_wireframe_{zits_wireframe}_result.png')
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"zits_{strategy[0].upper() + strategy[1:]}_wireframe_{zits_wireframe}_result.png",
|
||||
)
|
||||
|
||||
fx = 1.3
|
||||
assert_equal(
|
||||
model,
|
||||
cfg,
|
||||
f"zits_{strategy[0].upper() + strategy[1:]}_wireframe_{zits_wireframe}_fx_{fx}_result.png",
|
||||
fx=fx,
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user