update test
This commit is contained in:
parent
13b6371a53
commit
d4bd37682a
@ -18,8 +18,6 @@ def get_data(fx=1, fy=1.0, img_p=current_dir / "image.png", mask_p=current_dir /
|
||||
img = cv2.imread(str(img_p))
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
|
||||
mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
|
||||
|
||||
if fx != 1:
|
||||
img = cv2.resize(img, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
|
||||
mask = cv2.resize(mask, None, fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST)
|
||||
return img, mask
|
||||
@ -40,6 +38,7 @@ def get_config(strategy, **kwargs):
|
||||
|
||||
def assert_equal(model, config, gt_name, fx=1, fy=1, img_p=current_dir / "image.png", mask_p=current_dir / "mask.png"):
|
||||
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
||||
print(f"Input image shape: {img.shape}")
|
||||
res = model(img, mask, config)
|
||||
cv2.imwrite(
|
||||
str(save_dir / gt_name),
|
||||
@ -245,6 +244,7 @@ def test_runway_sd_1_5(strategy, sampler):
|
||||
f"runway_sd_{strategy.capitalize()}_{sampler}_result.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||
fx=1.3
|
||||
)
|
||||
|
||||
assert_equal(
|
||||
@ -253,6 +253,7 @@ def test_runway_sd_1_5(strategy, sampler):
|
||||
f"runway_sd_{strategy.capitalize()}_{sampler}_blur_mask_result.png",
|
||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
|
||||
fy=1.3
|
||||
)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user