2023-12-01 03:15:35 +01:00
|
|
|
import json
|
2023-11-16 14:12:06 +01:00
|
|
|
import os
|
2024-01-18 13:54:10 +01:00
|
|
|
from functools import lru_cache
|
2023-12-01 03:15:35 +01:00
|
|
|
from typing import List
|
2023-11-16 14:12:06 +01:00
|
|
|
|
|
|
|
from loguru import logger
|
|
|
|
from pathlib import Path
|
|
|
|
|
2024-01-05 08:19:23 +01:00
|
|
|
from iopaint.const import (
|
2023-12-27 15:00:07 +01:00
|
|
|
DEFAULT_MODEL_DIR,
|
2023-12-01 03:15:35 +01:00
|
|
|
DIFFUSERS_SD_CLASS_NAME,
|
2023-12-27 15:00:07 +01:00
|
|
|
DIFFUSERS_SD_INPAINT_CLASS_NAME,
|
2023-12-01 03:15:35 +01:00
|
|
|
DIFFUSERS_SDXL_CLASS_NAME,
|
2023-12-27 15:00:07 +01:00
|
|
|
DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
|
2024-01-21 16:25:50 +01:00
|
|
|
ANYTEXT_NAME,
|
2023-12-01 03:15:35 +01:00
|
|
|
)
|
2024-01-30 06:19:13 +01:00
|
|
|
from iopaint.model.original_sd_configs import get_config_files
|
2024-01-05 08:19:23 +01:00
|
|
|
from iopaint.model_info import ModelInfo, ModelType
|
2023-12-01 03:15:35 +01:00
|
|
|
|
2023-11-16 14:12:06 +01:00
|
|
|
|
2024-01-05 08:38:34 +01:00
|
|
|
def cli_download_model(model: str):
    """Download ``model`` so it is available locally.

    Registered erase models and the AnyText model are fetched through their own
    ``download()`` method from ``iopaint.model.models``. Any other name is treated
    as a Hugging Face model id and fetched with ``DiffusionPipeline.download``
    (fp16 variant, ``resume_download=True``), wrapped by
    ``handle_from_pretrained_exceptions``.

    Args:
        model: A key of ``iopaint.model.models`` or a Hugging Face model id.
    """
    from iopaint.model import models
    from iopaint.model.utils import handle_from_pretrained_exceptions

    # Both branches of the original if/elif performed the same steps, so they are
    # merged. Keeping the ANYTEXT_NAME test first preserves the original behaviour
    # of indexing `models` directly for AnyText.
    if model == ANYTEXT_NAME or (model in models and models[model].is_erase_model):
        logger.info(f"Downloading {model}...")
        models[model].download()
        logger.info("Done.")
    else:
        logger.info(f"Downloading model from Huggingface: {model}")
        from diffusers import DiffusionPipeline

        downloaded_path = handle_from_pretrained_exceptions(
            DiffusionPipeline.download,
            pretrained_model_name=model,
            variant="fp16",
            resume_download=True,
        )
        logger.info(f"Done. Downloaded to {downloaded_path}")
|
|
|
|
|
|
|
|
|
|
|
|
def folder_name_to_show_name(name: str) -> str:
    """Convert a Hugging Face cache folder name into an ``org/repo`` display name.

    Every ``models--`` substring is removed, then every remaining ``--`` becomes ``/``.
    For example, ``models--runwayml--stable-diffusion-inpainting`` becomes
    ``runwayml/stable-diffusion-inpainting``.
    """
    stripped = name.replace("models--", "")
    parts = stripped.split("--")
    return "/".join(parts)
|
|
|
|
|
|
|
|
|
2024-01-18 13:54:10 +01:00
|
|
|
@lru_cache(maxsize=512)
def get_sd_model_type(model_abs_path: str) -> ModelType:
    """Classify a single-file Stable Diffusion checkpoint as inpaint or regular.

    A file name containing "inpaint" (case-insensitive) is taken to be an inpaint
    model without loading anything. Otherwise the checkpoint is loaded once, from
    local files only, as a ``StableDiffusionInpaintPipeline`` with
    ``num_in_channels=9``:

    * if loading succeeds, the model is ``DIFFUSERS_SD_INPAINT``;
    * if it fails with the ``[320, 4, 3, 3]`` shape-mismatch ``ValueError``, the
      model is plain ``DIFFUSERS_SD``;
    * any other ``ValueError`` is re-raised.

    Results are memoized per path by ``lru_cache``.
    """
    if "inpaint" in Path(model_abs_path).name.lower():
        return ModelType.DIFFUSERS_SD_INPAINT

    from diffusers import StableDiffusionInpaintPipeline

    try:
        StableDiffusionInpaintPipeline.from_single_file(
            model_abs_path,
            load_safety_checker=False,
            local_files_only=True,
            num_in_channels=9,
            config_files=get_config_files(),
        )
    except ValueError as err:
        # The [320, 4, 3, 3] tensor is presumably the UNet's 4-channel input conv,
        # i.e. a non-inpaint checkpoint. TODO confirm against diffusers internals.
        if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" not in str(err):
            raise
        return ModelType.DIFFUSERS_SD
    return ModelType.DIFFUSERS_SD_INPAINT
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache(maxsize=512)
def get_sdxl_model_type(model_abs_path: str) -> ModelType:
    """Classify a single-file SDXL checkpoint as inpaint or regular.

    A file name containing "inpaint" (case-insensitive) is taken to be an inpaint
    model without loading anything. Only the file name is inspected, matching
    ``get_sd_model_type``. The original tested the whole absolute path, so a cache
    directory whose path contained "inpaint" mislabelled every SDXL model.

    Otherwise the checkpoint is loaded once, from local files only, as a
    ``StableDiffusionXLInpaintPipeline`` with ``num_in_channels=9``, and then
    classified as follows:

    * if the loaded UNet reports ``in_channels == 9``, it is ``DIFFUSERS_SDXL_INPAINT``;
    * if the UNet loads with any other channel count, it is ``DIFFUSERS_SDXL``;
    * if loading fails with the ``[320, 4, 3, 3]`` shape-mismatch ``ValueError``,
      it is ``DIFFUSERS_SDXL``;
    * any other ``ValueError`` is re-raised.

    Results are memoized per path by ``lru_cache``.
    """
    if "inpaint" in Path(model_abs_path).name.lower():
        return ModelType.DIFFUSERS_SDXL_INPAINT

    from diffusers import StableDiffusionXLInpaintPipeline

    try:
        model = StableDiffusionXLInpaintPipeline.from_single_file(
            model_abs_path,
            load_safety_checker=False,
            local_files_only=True,
            num_in_channels=9,
            config_files=get_config_files(),
        )
    except ValueError as err:
        if "Trying to set a tensor of shape torch.Size([320, 4, 3, 3])" not in str(err):
            raise
        return ModelType.DIFFUSERS_SDXL

    # https://github.com/huggingface/diffusers/issues/6610
    # Loading does not settle the question by itself for SDXL (presumably a
    # 4-channel UNet can load without raising), so check the loaded config.
    if model.unet.config.in_channels == 9:
        return ModelType.DIFFUSERS_SDXL_INPAINT
    return ModelType.DIFFUSERS_SDXL
|
|
|
|
|
|
|
|
|
2023-12-01 03:15:35 +01:00
|
|
|
def _load_model_type_cache(cache_file: Path) -> dict:
    """Return the ``{file name: model type}`` mapping stored in ``cache_file``.

    A missing, unreadable or malformed cache, or one that is not a JSON object,
    yields an empty dict, so every model is simply re-detected.
    """
    if not cache_file.exists():
        return {}
    try:
        with open(cache_file, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return {}
    return data if isinstance(data, dict) else {}


def _save_model_type_cache(cache_file: Path, model_type_cache: dict) -> None:
    """Write ``model_type_cache`` to ``cache_file`` as JSON (best-effort)."""
    try:
        with open(cache_file, "w", encoding="utf-8") as fw:
            json.dump(model_type_cache, fw, indent=2, ensure_ascii=False)
    except OSError as e:
        # A read-only model directory should not make the whole scan fail.
        logger.warning(f"Failed to write model type cache {cache_file}: {e}")


def _scan_single_file_dir(model_dir: Path, detect_model_type) -> List[ModelInfo]:
    """Return a ModelInfo for each ``.safetensors``/``.ckpt`` file directly in ``model_dir``.

    Model types are first looked up in ``model_dir/iopaint_cache.json``. On a cache
    miss they are detected with ``detect_model_type(absolute_path)``. The updated
    cache is written back once, after the scan, if ``model_dir`` exists.
    """
    cache_file = model_dir / "iopaint_cache.json"
    model_type_cache = _load_model_type_cache(cache_file)

    res = []
    for it in model_dir.glob("*.*"):
        if it.suffix not in (".safetensors", ".ckpt"):
            continue
        model_abs_path = str(it.absolute())
        model_type = model_type_cache.get(it.name)
        if model_type is None:
            model_type = detect_model_type(model_abs_path)
            model_type_cache[it.name] = model_type
        res.append(
            ModelInfo(
                name=it.name,
                path=model_abs_path,
                model_type=model_type,
                is_single_file_diffusers=True,
            )
        )

    if model_dir.exists():
        _save_model_type_cache(cache_file, model_type_cache)
    return res


def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
    """List single-file SD and SDXL checkpoints found under ``cache_dir``.

    SD checkpoints are read from ``cache_dir/stable_diffusion`` and classified
    with ``get_sd_model_type``. SDXL checkpoints are read from
    ``cache_dir/stable_diffusion_xl`` and classified with ``get_sdxl_model_type``.
    Each directory keeps its own ``iopaint_cache.json`` of detected types, so
    checkpoints do not have to be reloaded on every scan.

    Args:
        cache_dir: Root model directory (str or Path).

    Returns:
        SD models followed by SDXL models, each with ``is_single_file_diffusers=True``.
    """
    cache_dir = Path(cache_dir)
    res = _scan_single_file_dir(cache_dir / "stable_diffusion", get_sd_model_type)
    res.extend(
        _scan_single_file_dir(cache_dir / "stable_diffusion_xl", get_sdxl_model_type)
    )
    return res
|
|
|
|
|
|
|
|
|
2023-12-25 04:31:49 +01:00
|
|
|
def scan_inpaint_models(model_dir: Path) -> List[ModelInfo]:
    """Return a ``ModelType.INPAINT`` entry for every downloaded erase model.

    Candidates are the entries of ``iopaint.model.models`` whose
    ``is_erase_model`` is true and whose ``is_downloaded()`` returns true.
    ``model_dir`` is accepted for interface compatibility but is not read here.
    """
    from iopaint.model import models

    return [
        ModelInfo(name=model_name, path=model_name, model_type=ModelType.INPAINT)
        for model_name, model_cls in models.items()
        if model_cls.is_erase_model and model_cls.is_downloaded()
    ]
|
|
|
|
|
|
|
|
|
|
|
|
# Pipeline class names (from model_index.json) that are listed as DIFFUSERS_OTHER.
_OTHER_DIFFUSERS_CLASS_NAMES = (
    "StableDiffusionInstructPix2PixPipeline",
    "PaintByExamplePipeline",
    "KandinskyV22InpaintPipeline",
    "AnyText",
)


def _diffusers_model_type(name: str, class_name):
    """Map a cached diffusers repo to a ModelType, or return None if it is unsupported.

    Args:
        name: Display name (``org/repo``) of the cached model.
        class_name: The ``_class_name`` value read from its ``model_index.json``.
    """
    # PowerPaint repos are special-cased by name, whatever pipeline class they declare.
    if "PowerPaint" in name:
        return ModelType.DIFFUSERS_OTHER
    if not isinstance(class_name, str):
        return None
    known_types = {
        DIFFUSERS_SD_CLASS_NAME: ModelType.DIFFUSERS_SD,
        DIFFUSERS_SD_INPAINT_CLASS_NAME: ModelType.DIFFUSERS_SD_INPAINT,
        DIFFUSERS_SDXL_CLASS_NAME: ModelType.DIFFUSERS_SDXL,
        DIFFUSERS_SDXL_INPAINT_CLASS_NAME: ModelType.DIFFUSERS_SDXL_INPAINT,
    }
    if class_name in known_types:
        return known_types[class_name]
    if class_name in _OTHER_DIFFUSERS_CLASS_NAMES:
        return ModelType.DIFFUSERS_OTHER
    return None


def scan_models() -> List[ModelInfo]:
    """Collect every locally available model.

    The result is built in this order:

    1. downloaded erase models (``scan_inpaint_models``);
    2. single-file SD/SDXL checkpoints (``scan_single_file_diffusion_models``);
    3. diffusers repos in the Hugging Face hub cache whose ``model_index.json``
       declares a supported pipeline class.

    The model directory is ``$XDG_CACHE_HOME``, falling back to
    ``DEFAULT_MODEL_DIR``. Cache entries with an unreadable or malformed
    ``model_index.json`` are skipped rather than aborting the scan.
    """
    from huggingface_hub.constants import HF_HUB_CACHE

    model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR)
    available_models = []
    available_models.extend(scan_inpaint_models(model_dir))
    available_models.extend(scan_single_file_diffusion_models(model_dir))

    cache_dir = Path(HF_HUB_CACHE)
    seen_names = set()
    for it in cache_dir.glob("**/*/model_index.json"):
        # open() is inside the try: hub snapshots are symlinks and may dangle.
        try:
            with open(it, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            continue
        if not isinstance(data, dict):
            continue

        # Expected layout: models--<org>--<repo>/snapshots/<revision>/model_index.json
        name = folder_name_to_show_name(it.parent.parent.parent.name)
        if name in seen_names:
            continue
        model_type = _diffusers_model_type(name, data.get("_class_name"))
        if model_type is None:
            continue

        seen_names.add(name)
        available_models.append(
            ModelInfo(
                name=name,
                path=name,
                model_type=model_type,
            )
        )

    return available_models
|