backend: Better distinguish between the normal and inpaint models of stable diffusion.

This commit is contained in:
Qing 2024-01-18 20:54:10 +08:00
parent 7c6e62e164
commit e811481e78

View File

@ -1,8 +1,8 @@
import json import json
import os import os
from functools import lru_cache
from typing import List from typing import List
from huggingface_hub.constants import HF_HUB_CACHE
from loguru import logger from loguru import logger
from pathlib import Path from pathlib import Path
@ -41,39 +41,119 @@ def folder_name_to_show_name(name: str) -> str:
return name.replace("models--", "").replace("--", "/") return name.replace("models--", "").replace("--", "/")
@lru_cache(maxsize=512)
def get_sd_model_type(model_abs_path: str) -> ModelType:
if "inpaint" in Path(model_abs_path).name.lower():
model_type = ModelType.DIFFUSERS_SD_INPAINT
else:
# load once to check num_in_channels
from diffusers import StableDiffusionInpaintPipeline
try:
StableDiffusionInpaintPipeline.from_single_file(
model_abs_path,
load_safety_checker=False,
local_files_only=True,
num_in_channels=9,
)
model_type = ModelType.DIFFUSERS_SD_INPAINT
except ValueError as e:
if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e):
model_type = ModelType.DIFFUSERS_SD
else:
raise e
return model_type
@lru_cache()
def get_sdxl_model_type(model_abs_path: str) -> ModelType:
if "inpaint" in model_abs_path:
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
else:
# load once to check num_in_channels
from diffusers import StableDiffusionXLInpaintPipeline
try:
model = StableDiffusionXLInpaintPipeline.from_single_file(
model_abs_path,
load_safety_checker=False,
local_files_only=True,
num_in_channels=9,
)
if model.unet.config.in_channels == 9:
# https://github.com/huggingface/diffusers/issues/6610
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
else:
model_type = ModelType.DIFFUSERS_SDXL
except ValueError as e:
if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e):
model_type = ModelType.DIFFUSERS_SDXL
else:
raise e
return model_type
def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]: def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
cache_dir = Path(cache_dir) cache_dir = Path(cache_dir)
stable_diffusion_dir = cache_dir / "stable_diffusion" stable_diffusion_dir = cache_dir / "stable_diffusion"
stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl" cache_file = stable_diffusion_dir / "iopaint_cache.json"
# logger.info(f"Scanning single file sd/sdxl models in {cache_dir}") model_type_cache = {}
if cache_file.exists():
try:
with open(cache_file, "r", encoding="utf-8") as f:
model_type_cache = json.load(f)
assert isinstance(model_type_cache, dict)
except:
pass
res = [] res = []
for it in stable_diffusion_dir.glob(f"*.*"): for it in stable_diffusion_dir.glob(f"*.*"):
if it.suffix not in [".safetensors", ".ckpt"]: if it.suffix not in [".safetensors", ".ckpt"]:
continue continue
if "inpaint" in str(it).lower(): model_abs_path = str(it.absolute())
model_type = ModelType.DIFFUSERS_SD_INPAINT model_type = model_type_cache.get(it.name)
else: if model_type is None:
model_type = ModelType.DIFFUSERS_SD model_type = get_sd_model_type(model_abs_path)
model_type_cache[it.name] = model_type
res.append( res.append(
ModelInfo( ModelInfo(
name=it.name, name=it.name,
path=str(it.absolute()), path=model_abs_path,
model_type=model_type, model_type=model_type,
is_single_file_diffusers=True, is_single_file_diffusers=True,
) )
) )
if stable_diffusion_dir.exists():
with open(cache_file, "w", encoding="utf-8") as fw:
json.dump(model_type_cache, fw, indent=2, ensure_ascii=False)
stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
sdxl_cache_file = stable_diffusion_xl_dir / "iopaint_cache.json"
sdxl_model_type_cache = {}
if sdxl_cache_file.exists():
try:
with open(sdxl_cache_file, "r", encoding="utf-8") as f:
sdxl_model_type_cache = json.load(f)
assert isinstance(sdxl_model_type_cache, dict)
except:
pass
for it in stable_diffusion_xl_dir.glob(f"*.*"): for it in stable_diffusion_xl_dir.glob(f"*.*"):
if it.suffix not in [".safetensors", ".ckpt"]: if it.suffix not in [".safetensors", ".ckpt"]:
continue continue
if "inpaint" in str(it).lower(): model_abs_path = str(it.absolute())
model_type = ModelType.DIFFUSERS_SDXL_INPAINT model_type = sdxl_model_type_cache.get(it.name)
else: if model_type is None:
model_type = ModelType.DIFFUSERS_SDXL model_type = get_sdxl_model_type(model_abs_path)
sdxl_model_type_cache[it.name] = model_type
if stable_diffusion_xl_dir.exists():
with open(sdxl_cache_file, "w", encoding="utf-8") as fw:
json.dump(sdxl_model_type_cache, fw, indent=2, ensure_ascii=False)
res.append( res.append(
ModelInfo( ModelInfo(
name=it.name, name=it.name,
path=str(it.absolute()), path=model_abs_path,
model_type=model_type, model_type=model_type,
is_single_file_diffusers=True, is_single_file_diffusers=True,
) )
@ -100,6 +180,8 @@ def scan_inpaint_models(model_dir: Path) -> List[ModelInfo]:
def scan_models() -> List[ModelInfo]: def scan_models() -> List[ModelInfo]:
from huggingface_hub.constants import HF_HUB_CACHE
model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR) model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR)
available_models = [] available_models = []
available_models.extend(scan_inpaint_models(model_dir)) available_models.extend(scan_inpaint_models(model_dir))