From 8f942e27c45600142d344e99094783057d7e1ca1 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 15 Nov 2023 17:20:44 +0800 Subject: [PATCH] add test --- lama_cleaner/tests/test_sd_model.py | 30 +++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/lama_cleaner/tests/test_sd_model.py b/lama_cleaner/tests/test_sd_model.py index 50135a0..3acba71 100644 --- a/lama_cleaner/tests/test_sd_model.py +++ b/lama_cleaner/tests/test_sd_model.py @@ -177,6 +177,36 @@ def test_runway_sd_1_5_sd_scale( ) +@pytest.mark.parametrize("sd_device", ["mps"]) +@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) +@pytest.mark.parametrize("sampler", [SDSampler.k_euler_a]) +def test_runway_sd_sd_strength(sd_device, strategy, sampler): + if sd_device == "cuda" and not torch.cuda.is_available(): + return + + sd_steps = 50 if sd_device == "cuda" else 20 + model = ModelManager( + name="sd1.5", + device=torch.device(sd_device), + hf_access_token="", + sd_run_local=True, + disable_nsfw=True, + sd_cpu_textencoder=False, + ) + cfg = get_config( + strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps, sd_strength=0.8 + ) + cfg.sd_sampler = sampler + + assert_equal( + model, + cfg, + f"runway_sd_strength_0.8.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + ) + + @pytest.mark.parametrize("sd_device", ["cuda"]) @pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL]) @pytest.mark.parametrize("sampler", [SDSampler.k_euler_a])