diff --git a/iopaint/download.py b/iopaint/download.py index a587d90..18e3054 100644 --- a/iopaint/download.py +++ b/iopaint/download.py @@ -68,7 +68,7 @@ def get_sd_model_type(model_abs_path: str) -> ModelType: if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e): model_type = ModelType.DIFFUSERS_SD else: - raise e + logger.info(f"Ignore non sd or sdxl file: {model_abs_path}") return model_type @@ -96,7 +96,7 @@ def get_sdxl_model_type(model_abs_path: str) -> ModelType: if "but got torch.Size([320, 4, 3, 3])" in str(e): model_type = ModelType.DIFFUSERS_SDXL else: - raise e + logger.info(f"Ignore non sd or sdxl file: {model_abs_path}") return model_type diff --git a/iopaint/model/brushnet/brushnet_wrapper.py b/iopaint/model/brushnet/brushnet_wrapper.py index 69bd484..c7343d2 100644 --- a/iopaint/model/brushnet/brushnet_wrapper.py +++ b/iopaint/model/brushnet/brushnet_wrapper.py @@ -64,7 +64,7 @@ class BrushNetWrapper(DiffusionInpaintModel): self.model_id_or_path, torch_dtype=torch_dtype, load_safety_checker=not disable_nsfw_checker, - config_files=get_config_files(), + original_config_file=get_config_files()['v1'], brushnet=brushnet, **model_kwargs, ) diff --git a/iopaint/tests/test_brushnet.py b/iopaint/tests/test_brushnet.py index e46d194..0261b1d 100644 --- a/iopaint/tests/test_brushnet.py +++ b/iopaint/tests/test_brushnet.py @@ -42,7 +42,7 @@ def test_runway_brushnet(device, sampler): assert_equal( model, cfg, - f"brushnet_runway_1_5_freeu_device_{device}.png", + f"brushnet_random_mask_{device}.png", img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", ) @@ -74,16 +74,13 @@ def test_brushnet_local_file_path(device, sampler, name): brushnet_method=SD_BRUSHNET_CHOICES[1] ) cfg.sd_sampler = sampler - name = f"device_{device}_{sampler}_{name}" - - is_sdxl = "sd_xl" in name assert_equal( model, cfg, - f"brushnet_sd_local_model_{name}.png", + f"brushnet_segmentation_mask_{device}.png", img_p=current_dir / "overture-creations-5sI6fQgYIuo.png", mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png", - fx=1.5 if is_sdxl else 1, - fy=1.5 if is_sdxl else 1, + fx=1, + fy=1, )