# IOPaint/lama_cleaner/tests/test_model.py
# (Python source, 161 lines, 4.9 KiB.)
import pytest
2022-09-05 07:07:25 +02:00
import torch
2022-04-15 18:11:51 +02:00
from lama_cleaner.model_manager import ModelManager
2023-12-28 03:48:52 +01:00
from lama_cleaner.schema import HDStrategy, LDMSampler
from lama_cleaner.tests.utils import assert_equal, get_config, current_dir, check_device
2022-04-15 18:11:51 +02:00
2023-12-28 03:48:52 +01:00
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
@pytest.mark.parametrize(
    "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
)
def test_lama(device, strategy):
    """Run the "lama" model for each HD strategy at native scale and at 1.3x.

    ``check_device`` is called first; presumably it skips the test when the
    requested device is unavailable (NOTE(review): confirm in tests/utils.py).
    A fresh config from ``get_config`` is built for each ``assert_equal`` call.
    """
    check_device(device)
    model = ModelManager(name="lama", device=device)
    # Upper-case only the first character of the strategy value for the name.
    strategy_name = strategy[0].upper() + strategy[1:]
    assert_equal(
        model,
        get_config(strategy=strategy),
        f"lama_{strategy_name}_result.png",
    )
    fx = 1.3
    # Pass the same ``fx`` that is embedded in the result name so the name and
    # the actual scale factor cannot drift apart.
    assert_equal(
        model,
        get_config(strategy=strategy),
        f"lama_{strategy_name}_fx_{fx}_result.png",
        fx=fx,
    )
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize(
    "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
)
@pytest.mark.parametrize("ldm_sampler", [LDMSampler.ddim, LDMSampler.plms])
def test_ldm(device, strategy, ldm_sampler):
    """Run the "ldm" model for every HD strategy and LDM sampler.

    Each combination runs at native scale and again at 1.3x. The same config
    object is reused for both ``assert_equal`` calls.
    """
    check_device(device)
    model = ModelManager(name="ldm", device=device)
    cfg = get_config(strategy=strategy, ldm_sampler=ldm_sampler)
    strategy_name = strategy[0].upper() + strategy[1:]
    # Format the sampler by its value rather than the member itself. Formatting
    # a ``(str, Enum)`` member in an f-string gives "LDMSampler.ddim" on
    # Python 3.12+, which would change the result file name.
    # NOTE(review): assumes LDMSampler is a str-mixin Enum; getattr keeps
    # plain-string members working too.
    sampler_name = getattr(ldm_sampler, "value", ldm_sampler)
    prefix = f"ldm_{strategy_name}_{sampler_name}"
    assert_equal(model, cfg, f"{prefix}_result.png")
    fx = 1.3
    assert_equal(model, cfg, f"{prefix}_fx_{fx}_result.png", fx=fx)
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize(
    "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
)
@pytest.mark.parametrize("zits_wireframe", [False, True])
def test_zits(device, strategy, zits_wireframe):
    """Run the "zits" model with wireframe guidance on and off.

    Each combination runs at native scale and again at 1.3x, reusing one
    config object for both ``assert_equal`` calls.
    """
    check_device(device)
    model = ModelManager(name="zits", device=device)
    cfg = get_config(strategy=strategy, zits_wireframe=zits_wireframe)
    # Build the name prefix once so both result names follow the same
    # capitalisation rule (only the first character is upper-cased).
    strategy_name = strategy[0].upper() + strategy[1:]
    prefix = f"zits_{strategy_name}_wireframe_{zits_wireframe}"
    assert_equal(model, cfg, f"{prefix}_result.png")
    fx = 1.3
    assert_equal(model, cfg, f"{prefix}_fx_{fx}_result.png", fx=fx)
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
@pytest.mark.parametrize("no_half", [True, False])
def test_mat(device, strategy, no_half):
    """Exercise the "mat" model with ``no_half`` both enabled and disabled.

    Only the ORIGINAL HD strategy is parametrized. The result name handed to
    ``assert_equal`` is ``mat_<Strategy>_result.png``.
    """
    check_device(device)
    manager = ModelManager(name="mat", device=device, no_half=no_half)
    config = get_config(strategy=strategy)
    result_name = f"mat_{strategy.capitalize()}_result.png"
    assert_equal(manager, config, result_name)
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_fcf(device, strategy):
    """Run the "fcf" model on up-scaled inputs (ORIGINAL HD strategy only).

    Two scalings are exercised: fx=2/fy=2 and fx=3.8/fy=2. Each one gets its
    own result name so the runs do not collide.
    """
    check_device(device)
    model = ModelManager(name="fcf", device=device)
    cfg = get_config(strategy=strategy)
    strategy_name = strategy.capitalize()
    assert_equal(model, cfg, f"fcf_{strategy_name}_result.png", fx=2, fy=2)
    # A distinct name for the second scaling. Reusing the first name would make
    # this result collide with (overwrite or be compared against) the fx=2 one.
    fx, fy = 3.8, 2
    assert_equal(
        model,
        cfg,
        f"fcf_{strategy_name}_fx_{fx}_fy_{fy}_result.png",
        fx=fx,
        fy=fy,
    )
@pytest.mark.parametrize(
    "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
)
@pytest.mark.parametrize("cv2_flag", ["INPAINT_NS", "INPAINT_TELEA"])
@pytest.mark.parametrize("cv2_radius", [3, 15])
def test_cv2(strategy, cv2_flag, cv2_radius):
    """Run the "cv2" model over every strategy / flag / radius combination.

    The model is always constructed on ``torch.device("cpu")``, and the
    overture-creations image and mask from ``current_dir`` are used as input.
    """
    manager = ModelManager(name="cv2", device=torch.device("cpu"))
    config = get_config(strategy=strategy, cv2_flag=cv2_flag, cv2_radius=cv2_radius)
    image_path = current_dir / "overture-creations-5sI6fQgYIuo.png"
    mask_path = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"
    assert_equal(
        manager,
        config,
        f"cv2_{strategy.capitalize()}_{cv2_flag}_{cv2_radius}.png",
        img_p=image_path,
        mask_p=mask_path,
    )
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize(
    "strategy", [HDStrategy.ORIGINAL, HDStrategy.RESIZE, HDStrategy.CROP]
)
def test_manga(device, strategy):
    """Run the "manga" model for each HD strategy on the overture test image.

    The device string is wrapped in ``torch.device`` before it reaches
    ``ModelManager``.
    """
    check_device(device)
    torch_device = torch.device(device)
    manager = ModelManager(name="manga", device=torch_device)
    config = get_config(strategy=strategy)
    image_path = current_dir / "overture-creations-5sI6fQgYIuo.png"
    mask_path = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"
    assert_equal(
        manager,
        config,
        f"manga_{strategy.capitalize()}.png",
        img_p=image_path,
        mask_p=mask_path,
    )
@pytest.mark.parametrize("device", ["cuda", "mps", "cpu"])
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_mi_gan(device, strategy):
    """Run the "migan" model on a non-uniformly scaled input (fx=1.5, fy=1.7).

    The result name encodes the device, so each device gets its own
    ``migan_device_<device>.png``.
    """
    check_device(device)
    torch_device = torch.device(device)
    manager = ModelManager(name="migan", device=torch_device)
    config = get_config(strategy=strategy)
    image_path = current_dir / "overture-creations-5sI6fQgYIuo.png"
    mask_path = current_dir / "overture-creations-5sI6fQgYIuo_mask.png"
    assert_equal(
        manager,
        config,
        f"migan_device_{device}.png",
        img_p=image_path,
        mask_p=mask_path,
        fx=1.5,
        fy=1.7,
    )