From 7c1f83e71ddb5b0ef68cca87a8f8bb9ad1b8b1eb Mon Sep 17 00:00:00 2001 From: Qing Date: Mon, 8 Jan 2024 23:38:18 +0800 Subject: [PATCH] fix match_histograms --- iopaint/model/base.py | 3 +++ iopaint/tests/test_match_histograms.py | 36 ++++++++++++++++++++++++++ 2 files changed, 39 insertions(+) create mode 100644 iopaint/tests/test_match_histograms.py diff --git a/iopaint/model/base.py b/iopaint/model/base.py index 839fd94..f455481 100644 --- a/iopaint/model/base.py +++ b/iopaint/model/base.py @@ -211,6 +211,9 @@ class InpaintModel: def _match_histograms(self, source, reference, mask): transformed_channels = [] + if len(mask.shape) == 3: + mask = mask[:, :, -1] + for channel in range(source.shape[-1]): source_channel = source[:, :, channel] reference_channel = reference[:, :, channel] diff --git a/iopaint/tests/test_match_histograms.py b/iopaint/tests/test_match_histograms.py new file mode 100644 index 0000000..c20a283 --- /dev/null +++ b/iopaint/tests/test_match_histograms.py @@ -0,0 +1,36 @@ +import pytest +import torch + +from iopaint.model_manager import ModelManager +from iopaint.schema import SDSampler, HDStrategy +from iopaint.tests.utils import check_device, get_config, assert_equal, current_dir + + +@pytest.mark.parametrize("device", ["cuda", "mps"]) +@pytest.mark.parametrize("sampler", [SDSampler.ddim]) +def test_sd_match_histograms(device, sampler): + sd_steps = check_device(device) + + model = ModelManager( + name="runwayml/stable-diffusion-inpainting", + device=torch.device(device), + disable_nsfw=True, + sd_cpu_textencoder=False, + ) + cfg = get_config( + strategy=HDStrategy.ORIGINAL, + prompt="face of a fox, sitting on a bench", + sd_steps=sd_steps, + sd_guidance_scale=7.5, + sd_lcm_lora=False, + sd_match_histograms=True, + sd_sampler=sampler + ) + + assert_equal( + model, + cfg, + f"runway_sd_1_5_device_{device}_match_histograms.png", + img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", + mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", + )