diff --git a/iopaint/download.py b/iopaint/download.py index f8fc7c9..2f7ceef 100644 --- a/iopaint/download.py +++ b/iopaint/download.py @@ -1,8 +1,8 @@ import json import os +from functools import lru_cache from typing import List -from huggingface_hub.constants import HF_HUB_CACHE from loguru import logger from pathlib import Path @@ -41,39 +41,119 @@ def folder_name_to_show_name(name: str) -> str: return name.replace("models--", "").replace("--", "/") +@lru_cache(maxsize=512) +def get_sd_model_type(model_abs_path: str) -> ModelType: + if "inpaint" in Path(model_abs_path).name.lower(): + model_type = ModelType.DIFFUSERS_SD_INPAINT + else: + # load once to check num_in_channels + from diffusers import StableDiffusionInpaintPipeline + + try: + StableDiffusionInpaintPipeline.from_single_file( + model_abs_path, + load_safety_checker=False, + local_files_only=True, + num_in_channels=9, + ) + model_type = ModelType.DIFFUSERS_SD_INPAINT + except ValueError as e: + if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e): + model_type = ModelType.DIFFUSERS_SD + else: + raise e + return model_type + + +@lru_cache() +def get_sdxl_model_type(model_abs_path: str) -> ModelType: + if "inpaint" in model_abs_path: + model_type = ModelType.DIFFUSERS_SDXL_INPAINT + else: + # load once to check num_in_channels + from diffusers import StableDiffusionXLInpaintPipeline + + try: + model = StableDiffusionXLInpaintPipeline.from_single_file( + model_abs_path, + load_safety_checker=False, + local_files_only=True, + num_in_channels=9, + ) + if model.unet.config.in_channels == 9: + # https://github.com/huggingface/diffusers/issues/6610 + model_type = ModelType.DIFFUSERS_SDXL_INPAINT + else: + model_type = ModelType.DIFFUSERS_SDXL + except ValueError as e: + if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e): + model_type = ModelType.DIFFUSERS_SDXL + else: + raise e + return model_type + + def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]: cache_dir = Path(cache_dir) stable_diffusion_dir = cache_dir / "stable_diffusion" - stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl" - # logger.info(f"Scanning single file sd/sdxl models in {cache_dir}") + cache_file = stable_diffusion_dir / "iopaint_cache.json" + model_type_cache = {} + if cache_file.exists(): + try: + with open(cache_file, "r", encoding="utf-8") as f: + model_type_cache = json.load(f) + assert isinstance(model_type_cache, dict) + except: + pass + res = [] for it in stable_diffusion_dir.glob(f"*.*"): if it.suffix not in [".safetensors", ".ckpt"]: continue - if "inpaint" in str(it).lower(): - model_type = ModelType.DIFFUSERS_SD_INPAINT - else: - model_type = ModelType.DIFFUSERS_SD + model_abs_path = str(it.absolute()) + model_type = model_type_cache.get(it.name) + if model_type is None: + model_type = get_sd_model_type(model_abs_path) + model_type_cache[it.name] = model_type res.append( ModelInfo( name=it.name, - path=str(it.absolute()), + path=model_abs_path, model_type=model_type, is_single_file_diffusers=True, ) ) + if stable_diffusion_dir.exists(): + with open(cache_file, "w", encoding="utf-8") as fw: + json.dump(model_type_cache, fw, indent=2, ensure_ascii=False) + + stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl" + sdxl_cache_file = stable_diffusion_xl_dir / "iopaint_cache.json" + sdxl_model_type_cache = {} + if sdxl_cache_file.exists(): + try: + with open(sdxl_cache_file, "r", encoding="utf-8") as f: + sdxl_model_type_cache = json.load(f) + assert isinstance(sdxl_model_type_cache, dict) + except: + pass for it in stable_diffusion_xl_dir.glob(f"*.*"): if it.suffix not in [".safetensors", ".ckpt"]: continue - if "inpaint" in str(it).lower(): - model_type = ModelType.DIFFUSERS_SDXL_INPAINT - else: - model_type = ModelType.DIFFUSERS_SDXL + model_abs_path = str(it.absolute()) + model_type = sdxl_model_type_cache.get(it.name) + if model_type is None: + model_type = get_sdxl_model_type(model_abs_path) + sdxl_model_type_cache[it.name] = model_type + if stable_diffusion_xl_dir.exists(): + with open(sdxl_cache_file, "w", encoding="utf-8") as fw: + json.dump(sdxl_model_type_cache, fw, indent=2, ensure_ascii=False) + res.append( ModelInfo( name=it.name, - path=str(it.absolute()), + path=model_abs_path, model_type=model_type, is_single_file_diffusers=True, ) @@ -100,6 +180,8 @@ def scan_inpaint_models(model_dir: Path) -> List[ModelInfo]: def scan_models() -> List[ModelInfo]: + from huggingface_hub.constants import HF_HUB_CACHE + model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR) available_models = [] available_models.extend(scan_inpaint_models(model_dir))