update test
This commit is contained in:
parent
13b6371a53
commit
d4bd37682a
@ -18,8 +18,6 @@ def get_data(fx=1, fy=1.0, img_p=current_dir / "image.png", mask_p=current_dir /
|
|||||||
img = cv2.imread(str(img_p))
|
img = cv2.imread(str(img_p))
|
||||||
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
|
img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
|
||||||
mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
|
mask = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
|
||||||
|
|
||||||
if fx != 1:
|
|
||||||
img = cv2.resize(img, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
|
img = cv2.resize(img, None, fx=fx, fy=fy, interpolation=cv2.INTER_AREA)
|
||||||
mask = cv2.resize(mask, None, fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST)
|
mask = cv2.resize(mask, None, fx=fx, fy=fy, interpolation=cv2.INTER_NEAREST)
|
||||||
return img, mask
|
return img, mask
|
||||||
@ -40,6 +38,7 @@ def get_config(strategy, **kwargs):
|
|||||||
|
|
||||||
def assert_equal(model, config, gt_name, fx=1, fy=1, img_p=current_dir / "image.png", mask_p=current_dir / "mask.png"):
|
def assert_equal(model, config, gt_name, fx=1, fy=1, img_p=current_dir / "image.png", mask_p=current_dir / "mask.png"):
|
||||||
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
img, mask = get_data(fx=fx, fy=fy, img_p=img_p, mask_p=mask_p)
|
||||||
|
print(f"Input image shape: {img.shape}")
|
||||||
res = model(img, mask, config)
|
res = model(img, mask, config)
|
||||||
cv2.imwrite(
|
cv2.imwrite(
|
||||||
str(save_dir / gt_name),
|
str(save_dir / gt_name),
|
||||||
@ -245,6 +244,7 @@ def test_runway_sd_1_5(strategy, sampler):
|
|||||||
f"runway_sd_{strategy.capitalize()}_{sampler}_result.png",
|
f"runway_sd_{strategy.capitalize()}_{sampler}_result.png",
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
|
||||||
|
fx=1.3
|
||||||
)
|
)
|
||||||
|
|
||||||
assert_equal(
|
assert_equal(
|
||||||
@ -253,6 +253,7 @@ def test_runway_sd_1_5(strategy, sampler):
|
|||||||
f"runway_sd_{strategy.capitalize()}_{sampler}_blur_mask_result.png",
|
f"runway_sd_{strategy.capitalize()}_{sampler}_blur_mask_result.png",
|
||||||
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
|
||||||
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
|
mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask_blur.png",
|
||||||
|
fy=1.3
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user