diff --git a/iopaint/download.py b/iopaint/download.py index 2ebd7fc..51fd84f 100644 --- a/iopaint/download.py +++ b/iopaint/download.py @@ -25,11 +25,11 @@ def cli_download_model(model: str): if model in models and models[model].is_erase_model: logger.info(f"Downloading {model}...") models[model].download() - logger.info(f"Done.") + logger.info("Done.") elif model == ANYTEXT_NAME: logger.info(f"Downloading {model}...") models[model].download() - logger.info(f"Done.") + logger.info("Done.") else: logger.info(f"Downloading model from Huggingface: {model}") from diffusers import DiffusionPipeline @@ -60,7 +60,7 @@ def get_sd_model_type(model_abs_path: str) -> ModelType: model_abs_path, load_safety_checker=False, num_in_channels=9, - config_files=get_config_files(), + original_config_file=get_config_files()['v1'] ) model_type = ModelType.DIFFUSERS_SD_INPAINT except ValueError as e: @@ -84,7 +84,7 @@ def get_sdxl_model_type(model_abs_path: str) -> ModelType: model_abs_path, load_safety_checker=False, num_in_channels=9, - config_files=get_config_files(), + original_config_file=get_config_files()['xl'], ) if model.unet.config.in_channels == 9: # https://github.com/huggingface/diffusers/issues/6610 @@ -113,7 +113,7 @@ def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]: pass res = [] - for it in stable_diffusion_dir.glob(f"*.*"): + for it in stable_diffusion_dir.glob("*.*"): if it.suffix not in [".safetensors", ".ckpt"]: continue model_abs_path = str(it.absolute()) @@ -144,7 +144,7 @@ def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]: except: pass - for it in stable_diffusion_xl_dir.glob(f"*.*"): + for it in stable_diffusion_xl_dir.glob("*.*"): if it.suffix not in [".safetensors", ".ckpt"]: continue model_abs_path = str(it.absolute()) diff --git a/iopaint/model/controlnet.py b/iopaint/model/controlnet.py index d52db01..7b4d243 100644 --- a/iopaint/model/controlnet.py +++ b/iopaint/model/controlnet.py @@ -71,6 +71,7 @@ class ControlNet(DiffusionInpaintModel): use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) self.torch_dtype = torch_dtype + original_config_file_name = "v1" if model_info.model_type in [ ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD_INPAINT, @@ -78,6 +79,8 @@ class ControlNet(DiffusionInpaintModel): from diffusers import ( StableDiffusionControlNetInpaintPipeline as PipeClass, ) + original_config_file_name = "v1" + elif model_info.model_type in [ ModelType.DIFFUSERS_SDXL, ModelType.DIFFUSERS_SDXL_INPAINT, @@ -85,6 +88,7 @@ class ControlNet(DiffusionInpaintModel): from diffusers import ( StableDiffusionXLControlNetInpaintPipeline as PipeClass, ) + original_config_file_name = "xl" controlnet = ControlNetModel.from_pretrained( pretrained_model_name_or_path=controlnet_method, @@ -103,7 +107,7 @@ class ControlNet(DiffusionInpaintModel): controlnet=controlnet, load_safety_checker=not disable_nsfw_checker, torch_dtype=torch_dtype, - config_files=get_config_files(), + original_config_file=get_config_files()[original_config_file_name], **model_kwargs, ) else: diff --git a/iopaint/model/sd.py b/iopaint/model/sd.py index 8f42fff..2f6698c 100644 --- a/iopaint/model/sd.py +++ b/iopaint/model/sd.py @@ -52,7 +52,7 @@ class SD(DiffusionInpaintModel): self.model_id_or_path, torch_dtype=torch_dtype, load_safety_checker=not disable_nsfw_checker, - config_files=get_config_files(), + original_config_file=get_config_files()['v1'], **model_kwargs, ) else: diff --git a/iopaint/model/sdxl.py b/iopaint/model/sdxl.py index 29312b1..e6d66a8 100644 --- a/iopaint/model/sdxl.py +++ b/iopaint/model/sdxl.py @@ -42,7 +42,7 @@ class SDXL(DiffusionInpaintModel): torch_dtype=torch_dtype, num_in_channels=num_in_channels, load_safety_checker=False, - config_files=get_config_files() + original_config_file=get_config_files()['xl'], ) else: model_kwargs = {