update test

This commit is contained in:
Qing 2022-12-11 20:27:32 +08:00
parent 03965c69e6
commit 41e2858a7c
2 changed files with 3 additions and 5 deletions

View File

@ -38,7 +38,7 @@ def assert_equal(
@pytest.mark.parametrize("strategy", [HDStrategy.ORIGINAL])
def test_paint_by_example(strategy):
model = ModelManager(name="paint_by_example", device=device)
cfg = get_config(strategy, paint_by_example_steps=30)
cfg = get_config(strategy, paint_by_example_steps=30 if device == 'cuda' else 1)
assert_equal(
model,
cfg,

View File

@ -1,12 +1,10 @@
import os
from pathlib import Path
import cv2
import pytest
import torch
from lama_cleaner.model_manager import ModelManager
from lama_cleaner.schema import Config, HDStrategy, LDMSampler, SDSampler
from lama_cleaner.schema import HDStrategy, SDSampler
from lama_cleaner.tests.test_model import get_config, assert_equal
current_dir = Path(__file__).parent.absolute().resolve()
@ -96,7 +94,7 @@ def test_runway_sd_1_5_negative_prompt(sd_device, strategy, sampler):
if sd_device == 'cuda' and not torch.cuda.is_available():
return
sd_steps = 50
sd_steps = 50 if sd_device == 'cuda' else 1
model = ModelManager(name="sd1.5",
device=torch.device(sd_device),
hf_access_token="",