update test

This commit is contained in:
Qing 2023-03-30 21:06:07 +08:00
parent 6911d3ce16
commit 07ae89b7c0

View File

@ -4,7 +4,12 @@ import cv2
import pytest import pytest
import torch.cuda import torch.cuda
from lama_cleaner.plugins import RemoveBG, RealESRGANUpscaler, GFPGANPlugin from lama_cleaner.plugins import (
RemoveBG,
RealESRGANUpscaler,
GFPGANPlugin,
RestoreFormerPlugin,
)
current_dir = Path(__file__).parent.absolute().resolve() current_dir = Path(__file__).parent.absolute().resolve()
save_dir = current_dir / "result" save_dir = current_dir / "result"
@ -48,3 +53,14 @@ def test_gfpgan(device):
model = GFPGANPlugin(device) model = GFPGANPlugin(device)
res = model(rgb_img, None, None) res = model(rgb_img, None, None)
_save(res, f"test_gfpgan_{device}.png") _save(res, f"test_gfpgan_{device}.png")
@pytest.mark.parametrize("device", ["cuda", "cpu", "mps"])
def test_restoreformer(device):
if device == "cuda" and not torch.cuda.is_available():
return
if device == "mps" and not torch.backends.mps.is_available():
return
model = RestoreFormerPlugin(device)
res = model(rgb_img, None, None)
_save(res, f"test_restoreformer_{device}.png")