backend: Better distinguish between the normal and inpaint models of stable diffusion.
This commit is contained in:
parent
7c6e62e164
commit
e811481e78
@ -1,8 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from functools import lru_cache
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from huggingface_hub.constants import HF_HUB_CACHE
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@ -41,39 +41,119 @@ def folder_name_to_show_name(name: str) -> str:
|
|||||||
return name.replace("models--", "").replace("--", "/")
|
return name.replace("models--", "").replace("--", "/")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=512)
|
||||||
|
def get_sd_model_type(model_abs_path: str) -> ModelType:
|
||||||
|
if "inpaint" in Path(model_abs_path).name.lower():
|
||||||
|
model_type = ModelType.DIFFUSERS_SD_INPAINT
|
||||||
|
else:
|
||||||
|
# load once to check num_in_channels
|
||||||
|
from diffusers import StableDiffusionInpaintPipeline
|
||||||
|
|
||||||
|
try:
|
||||||
|
StableDiffusionInpaintPipeline.from_single_file(
|
||||||
|
model_abs_path,
|
||||||
|
load_safety_checker=False,
|
||||||
|
local_files_only=True,
|
||||||
|
num_in_channels=9,
|
||||||
|
)
|
||||||
|
model_type = ModelType.DIFFUSERS_SD_INPAINT
|
||||||
|
except ValueError as e:
|
||||||
|
if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e):
|
||||||
|
model_type = ModelType.DIFFUSERS_SD
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
return model_type
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache()
|
||||||
|
def get_sdxl_model_type(model_abs_path: str) -> ModelType:
|
||||||
|
if "inpaint" in model_abs_path:
|
||||||
|
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
|
||||||
|
else:
|
||||||
|
# load once to check num_in_channels
|
||||||
|
from diffusers import StableDiffusionXLInpaintPipeline
|
||||||
|
|
||||||
|
try:
|
||||||
|
model = StableDiffusionXLInpaintPipeline.from_single_file(
|
||||||
|
model_abs_path,
|
||||||
|
load_safety_checker=False,
|
||||||
|
local_files_only=True,
|
||||||
|
num_in_channels=9,
|
||||||
|
)
|
||||||
|
if model.unet.config.in_channels == 9:
|
||||||
|
# https://github.com/huggingface/diffusers/issues/6610
|
||||||
|
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
|
||||||
|
else:
|
||||||
|
model_type = ModelType.DIFFUSERS_SDXL
|
||||||
|
except ValueError as e:
|
||||||
|
if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" in str(e):
|
||||||
|
model_type = ModelType.DIFFUSERS_SDXL
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
return model_type
|
||||||
|
|
||||||
|
|
||||||
def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
|
def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
|
||||||
cache_dir = Path(cache_dir)
|
cache_dir = Path(cache_dir)
|
||||||
stable_diffusion_dir = cache_dir / "stable_diffusion"
|
stable_diffusion_dir = cache_dir / "stable_diffusion"
|
||||||
stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
|
cache_file = stable_diffusion_dir / "iopaint_cache.json"
|
||||||
# logger.info(f"Scanning single file sd/sdxl models in {cache_dir}")
|
model_type_cache = {}
|
||||||
|
if cache_file.exists():
|
||||||
|
try:
|
||||||
|
with open(cache_file, "r", encoding="utf-8") as f:
|
||||||
|
model_type_cache = json.load(f)
|
||||||
|
assert isinstance(model_type_cache, dict)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
res = []
|
res = []
|
||||||
for it in stable_diffusion_dir.glob(f"*.*"):
|
for it in stable_diffusion_dir.glob(f"*.*"):
|
||||||
if it.suffix not in [".safetensors", ".ckpt"]:
|
if it.suffix not in [".safetensors", ".ckpt"]:
|
||||||
continue
|
continue
|
||||||
if "inpaint" in str(it).lower():
|
model_abs_path = str(it.absolute())
|
||||||
model_type = ModelType.DIFFUSERS_SD_INPAINT
|
model_type = model_type_cache.get(it.name)
|
||||||
else:
|
if model_type is None:
|
||||||
model_type = ModelType.DIFFUSERS_SD
|
model_type = get_sd_model_type(model_abs_path)
|
||||||
|
model_type_cache[it.name] = model_type
|
||||||
res.append(
|
res.append(
|
||||||
ModelInfo(
|
ModelInfo(
|
||||||
name=it.name,
|
name=it.name,
|
||||||
path=str(it.absolute()),
|
path=model_abs_path,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
is_single_file_diffusers=True,
|
is_single_file_diffusers=True,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
if stable_diffusion_dir.exists():
|
||||||
|
with open(cache_file, "w", encoding="utf-8") as fw:
|
||||||
|
json.dump(model_type_cache, fw, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
|
||||||
|
sdxl_cache_file = stable_diffusion_xl_dir / "iopaint_cache.json"
|
||||||
|
sdxl_model_type_cache = {}
|
||||||
|
if sdxl_cache_file.exists():
|
||||||
|
try:
|
||||||
|
with open(sdxl_cache_file, "r", encoding="utf-8") as f:
|
||||||
|
sdxl_model_type_cache = json.load(f)
|
||||||
|
assert isinstance(sdxl_model_type_cache, dict)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
for it in stable_diffusion_xl_dir.glob(f"*.*"):
|
for it in stable_diffusion_xl_dir.glob(f"*.*"):
|
||||||
if it.suffix not in [".safetensors", ".ckpt"]:
|
if it.suffix not in [".safetensors", ".ckpt"]:
|
||||||
continue
|
continue
|
||||||
if "inpaint" in str(it).lower():
|
model_abs_path = str(it.absolute())
|
||||||
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
|
model_type = sdxl_model_type_cache.get(it.name)
|
||||||
else:
|
if model_type is None:
|
||||||
model_type = ModelType.DIFFUSERS_SDXL
|
model_type = get_sdxl_model_type(model_abs_path)
|
||||||
|
sdxl_model_type_cache[it.name] = model_type
|
||||||
|
if stable_diffusion_xl_dir.exists():
|
||||||
|
with open(sdxl_cache_file, "w", encoding="utf-8") as fw:
|
||||||
|
json.dump(sdxl_model_type_cache, fw, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
res.append(
|
res.append(
|
||||||
ModelInfo(
|
ModelInfo(
|
||||||
name=it.name,
|
name=it.name,
|
||||||
path=str(it.absolute()),
|
path=model_abs_path,
|
||||||
model_type=model_type,
|
model_type=model_type,
|
||||||
is_single_file_diffusers=True,
|
is_single_file_diffusers=True,
|
||||||
)
|
)
|
||||||
@ -100,6 +180,8 @@ def scan_inpaint_models(model_dir: Path) -> List[ModelInfo]:
|
|||||||
|
|
||||||
|
|
||||||
def scan_models() -> List[ModelInfo]:
|
def scan_models() -> List[ModelInfo]:
|
||||||
|
from huggingface_hub.constants import HF_HUB_CACHE
|
||||||
|
|
||||||
model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR)
|
model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR)
|
||||||
available_models = []
|
available_models = []
|
||||||
available_models.extend(scan_inpaint_models(model_dir))
|
available_models.extend(scan_inpaint_models(model_dir))
|
||||||
|
Loading…
Reference in New Issue
Block a user