backend: Better distinguish between the normal and inpaint models of stable diffusion.
This commit is contained in:
parent
7c6e62e164
commit
e811481e78
@ -1,8 +1,8 @@
|
||||
import json
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import List
|
||||
|
||||
from huggingface_hub.constants import HF_HUB_CACHE
|
||||
from loguru import logger
|
||||
from pathlib import Path
|
||||
|
||||
@ -41,39 +41,119 @@ def folder_name_to_show_name(name: str) -> str:
|
||||
return name.replace("models--", "").replace("--", "/")
|
||||
|
||||
|
||||
@lru_cache(maxsize=512)
|
||||
def get_sd_model_type(model_abs_path: str) -> ModelType:
|
||||
if "inpaint" in Path(model_abs_path).name.lower():
|
||||
model_type = ModelType.DIFFUSERS_SD_INPAINT
|
||||
else:
|
||||
# load once to check num_in_channels
|
||||
from diffusers import StableDiffusionInpaintPipeline
|
||||
|
||||
try:
|
||||
StableDiffusionInpaintPipeline.from_single_file(
|
||||
model_abs_path,
|
||||
load_safety_checker=False,
|
||||
local_files_only=True,
|
||||
num_in_channels=9,
|
||||
)
|
||||
model_type = ModelType.DIFFUSERS_SD_INPAINT
|
||||
except ValueError as e:
|
||||
if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e):
|
||||
model_type = ModelType.DIFFUSERS_SD
|
||||
else:
|
||||
raise e
|
||||
return model_type
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def get_sdxl_model_type(model_abs_path: str) -> ModelType:
|
||||
if "inpaint" in model_abs_path:
|
||||
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
|
||||
else:
|
||||
# load once to check num_in_channels
|
||||
from diffusers import StableDiffusionXLInpaintPipeline
|
||||
|
||||
try:
|
||||
model = StableDiffusionXLInpaintPipeline.from_single_file(
|
||||
model_abs_path,
|
||||
load_safety_checker=False,
|
||||
local_files_only=True,
|
||||
num_in_channels=9,
|
||||
)
|
||||
if model.unet.config.in_channels == 9:
|
||||
# https://github.com/huggingface/diffusers/issues/6610
|
||||
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
|
||||
else:
|
||||
model_type = ModelType.DIFFUSERS_SDXL
|
||||
except ValueError as e:
|
||||
if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e):
|
||||
model_type = ModelType.DIFFUSERS_SDXL
|
||||
else:
|
||||
raise e
|
||||
return model_type
|
||||
|
||||
|
||||
def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
|
||||
cache_dir = Path(cache_dir)
|
||||
stable_diffusion_dir = cache_dir / "stable_diffusion"
|
||||
stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
|
||||
# logger.info(f"Scanning single file sd/sdxl models in {cache_dir}")
|
||||
cache_file = stable_diffusion_dir / "iopaint_cache.json"
|
||||
model_type_cache = {}
|
||||
if cache_file.exists():
|
||||
try:
|
||||
with open(cache_file, "r", encoding="utf-8") as f:
|
||||
model_type_cache = json.load(f)
|
||||
assert isinstance(model_type_cache, dict)
|
||||
except:
|
||||
pass
|
||||
|
||||
res = []
|
||||
for it in stable_diffusion_dir.glob(f"*.*"):
|
||||
if it.suffix not in [".safetensors", ".ckpt"]:
|
||||
continue
|
||||
if "inpaint" in str(it).lower():
|
||||
model_type = ModelType.DIFFUSERS_SD_INPAINT
|
||||
else:
|
||||
model_type = ModelType.DIFFUSERS_SD
|
||||
model_abs_path = str(it.absolute())
|
||||
model_type = model_type_cache.get(it.name)
|
||||
if model_type is None:
|
||||
model_type = get_sd_model_type(model_abs_path)
|
||||
model_type_cache[it.name] = model_type
|
||||
res.append(
|
||||
ModelInfo(
|
||||
name=it.name,
|
||||
path=str(it.absolute()),
|
||||
path=model_abs_path,
|
||||
model_type=model_type,
|
||||
is_single_file_diffusers=True,
|
||||
)
|
||||
)
|
||||
if stable_diffusion_dir.exists():
|
||||
with open(cache_file, "w", encoding="utf-8") as fw:
|
||||
json.dump(model_type_cache, fw, indent=2, ensure_ascii=False)
|
||||
|
||||
stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
|
||||
sdxl_cache_file = stable_diffusion_xl_dir / "iopaint_cache.json"
|
||||
sdxl_model_type_cache = {}
|
||||
if sdxl_cache_file.exists():
|
||||
try:
|
||||
with open(sdxl_cache_file, "r", encoding="utf-8") as f:
|
||||
sdxl_model_type_cache = json.load(f)
|
||||
assert isinstance(sdxl_model_type_cache, dict)
|
||||
except:
|
||||
pass
|
||||
|
||||
for it in stable_diffusion_xl_dir.glob(f"*.*"):
|
||||
if it.suffix not in [".safetensors", ".ckpt"]:
|
||||
continue
|
||||
if "inpaint" in str(it).lower():
|
||||
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
|
||||
else:
|
||||
model_type = ModelType.DIFFUSERS_SDXL
|
||||
model_abs_path = str(it.absolute())
|
||||
model_type = sdxl_model_type_cache.get(it.name)
|
||||
if model_type is None:
|
||||
model_type = get_sdxl_model_type(model_abs_path)
|
||||
sdxl_model_type_cache[it.name] = model_type
|
||||
if stable_diffusion_xl_dir.exists():
|
||||
with open(sdxl_cache_file, "w", encoding="utf-8") as fw:
|
||||
json.dump(sdxl_model_type_cache, fw, indent=2, ensure_ascii=False)
|
||||
|
||||
res.append(
|
||||
ModelInfo(
|
||||
name=it.name,
|
||||
path=str(it.absolute()),
|
||||
path=model_abs_path,
|
||||
model_type=model_type,
|
||||
is_single_file_diffusers=True,
|
||||
)
|
||||
@ -100,6 +180,8 @@ def scan_inpaint_models(model_dir: Path) -> List[ModelInfo]:
|
||||
|
||||
|
||||
def scan_models() -> List[ModelInfo]:
|
||||
from huggingface_hub.constants import HF_HUB_CACHE
|
||||
|
||||
model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR)
|
||||
available_models = []
|
||||
available_models.extend(scan_inpaint_models(model_dir))
|
||||
|
Loading…
Reference in New Issue
Block a user