IOPaint/iopaint/download.py

146 lines
4.8 KiB
Python
Raw Normal View History

2023-12-01 03:15:35 +01:00
import json
2023-11-16 14:12:06 +01:00
import os
2023-12-01 03:15:35 +01:00
from typing import List
2023-11-16 14:12:06 +01:00
2024-01-02 07:34:36 +01:00
from huggingface_hub.constants import HF_HUB_CACHE
2023-11-16 14:12:06 +01:00
from loguru import logger
from pathlib import Path
2024-01-05 08:19:23 +01:00
from iopaint.const import (
2023-12-27 15:00:07 +01:00
DEFAULT_MODEL_DIR,
2023-12-01 03:15:35 +01:00
DIFFUSERS_SD_CLASS_NAME,
2023-12-27 15:00:07 +01:00
DIFFUSERS_SD_INPAINT_CLASS_NAME,
2023-12-01 03:15:35 +01:00
DIFFUSERS_SDXL_CLASS_NAME,
2023-12-27 15:00:07 +01:00
DIFFUSERS_SDXL_INPAINT_CLASS_NAME,
2023-12-01 03:15:35 +01:00
)
2024-01-05 08:19:23 +01:00
from iopaint.model_info import ModelInfo, ModelType
2023-12-01 03:15:35 +01:00
2023-11-16 14:12:06 +01:00
2024-01-05 08:38:34 +01:00
def cli_download_model(model: str):
    """Download ``model`` for use by IOPaint (CLI entry point).

    If ``model`` names an entry in ``iopaint.model.models`` whose
    ``is_erase_model`` flag is set, that entry's own ``download()`` is used.
    Any other value is passed to ``DiffusionPipeline.download`` as
    ``pretrained_model_name`` (requesting the ``fp16`` variant with
    ``resume_download=True``), wrapped by ``handle_from_pretrained_exceptions``.
    """
    # Imported lazily, as in the original, so importing this module stays cheap.
    from iopaint.model import models
    from iopaint.model.utils import handle_from_pretrained_exceptions

    is_erase = model in models and models[model].is_erase_model
    if is_erase:
        logger.info(f"Downloading {model}...")
        models[model].download()
        logger.info("Done.")
        return

    logger.info(f"Downloading model from Huggingface: {model}")
    # diffusers is only needed for this branch, so import it here.
    from diffusers import DiffusionPipeline

    downloaded_path = handle_from_pretrained_exceptions(
        DiffusionPipeline.download,
        pretrained_model_name=model,
        variant="fp16",
        resume_download=True,
    )
    logger.info(f"Done. Downloaded to {downloaded_path}")
def folder_name_to_show_name(name: str) -> str:
    """Turn a hub cache folder name into a display name.

    Every ``"models--"`` occurrence is removed, then every remaining ``"--"``
    separator becomes ``"/"``.  For example,
    ``"models--runwayml--stable-diffusion-inpainting"`` becomes
    ``"runwayml/stable-diffusion-inpainting"``.
    """
    stripped = name.replace("models--", "")
    parts = stripped.split("--")
    return "/".join(parts)
def _scan_single_file_dir(
    directory: Path, inpaint_type: ModelType, base_type: ModelType
) -> List[ModelInfo]:
    """Collect ``.safetensors`` / ``.ckpt`` files directly inside ``directory``.

    Each file whose *name* contains ``"inpaint"`` (case-insensitive) is tagged
    ``inpaint_type``; every other file is tagged ``base_type``.  A missing
    ``directory`` simply yields no results.
    """
    found = []
    for it in directory.glob("*.*"):
        if it.suffix not in (".safetensors", ".ckpt"):
            continue
        # Inspect only the file name: checking the full path would tag every
        # checkpoint as inpaint whenever a parent directory contains "inpaint".
        if "inpaint" in it.name.lower():
            model_type = inpaint_type
        else:
            model_type = base_type
        found.append(
            ModelInfo(
                name=it.name,
                path=str(it.absolute()),
                model_type=model_type,
                is_single_file_diffusers=True,
            )
        )
    return found


def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
    """List single-file SD / SDXL checkpoints stored under ``cache_dir``.

    Looks in ``<cache_dir>/stable_diffusion`` (SD 1.x/2.x) and
    ``<cache_dir>/stable_diffusion_xl`` (SDXL).  The SD results come first,
    followed by the SDXL results.

    Args:
        cache_dir: Base model directory (``str`` or ``Path``).

    Returns:
        One ``ModelInfo`` per checkpoint, with ``is_single_file_diffusers=True``.
    """
    cache_dir = Path(cache_dir)
    res = _scan_single_file_dir(
        cache_dir / "stable_diffusion",
        ModelType.DIFFUSERS_SD_INPAINT,
        ModelType.DIFFUSERS_SD,
    )
    res.extend(
        _scan_single_file_dir(
            cache_dir / "stable_diffusion_xl",
            ModelType.DIFFUSERS_SDXL_INPAINT,
            ModelType.DIFFUSERS_SDXL,
        )
    )
    return res
2023-12-25 04:31:49 +01:00
def scan_inpaint_models(model_dir: Path) -> List[ModelInfo]:
    """Return a ``ModelInfo`` for each registered erase model already downloaded.

    Walks ``iopaint.model.models`` and keeps the entries whose
    ``is_erase_model`` is truthy and whose ``is_downloaded()`` returns truthy.
    Each result uses the registry key as both ``name`` and ``path`` and has
    type ``ModelType.INPAINT``.

    Note: ``model_dir`` is not read by this function.
    """
    from iopaint.model import models

    return [
        ModelInfo(
            name=model_name,
            path=model_name,
            model_type=ModelType.INPAINT,
        )
        for model_name, entry in models.items()
        if entry.is_erase_model and entry.is_downloaded()
    ]
def _diffusers_model_type(name: str, class_name):
    """Classify a cached diffusers repo, returning its ``ModelType`` or ``None``.

    ``None`` means the pipeline class is not one IOPaint lists.  Repos whose
    display name contains ``"PowerPaint"`` are classified by name before the
    pipeline class is consulted.
    """
    if "PowerPaint" in name:
        return ModelType.DIFFUSERS_OTHER
    if class_name == DIFFUSERS_SD_CLASS_NAME:
        return ModelType.DIFFUSERS_SD
    if class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
        return ModelType.DIFFUSERS_SD_INPAINT
    if class_name == DIFFUSERS_SDXL_CLASS_NAME:
        return ModelType.DIFFUSERS_SDXL
    if class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME:
        return ModelType.DIFFUSERS_SDXL_INPAINT
    if class_name in (
        "StableDiffusionInstructPix2PixPipeline",
        "PaintByExamplePipeline",
        "KandinskyV22InpaintPipeline",
    ):
        return ModelType.DIFFUSERS_OTHER
    return None


def scan_models() -> List[ModelInfo]:
    """List every model IOPaint can offer: erase models, single-file SD/SDXL
    checkpoints, and diffusers pipelines in the Hugging Face hub cache.

    The local model directory is ``$XDG_CACHE_HOME`` when set, otherwise
    ``DEFAULT_MODEL_DIR``.  Hub-cache repos are deduplicated by display name;
    a repo whose ``model_index.json`` cannot be read or parsed is logged and
    skipped rather than aborting the whole scan.
    """
    model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR)
    available_models = []
    available_models.extend(scan_inpaint_models(model_dir))
    available_models.extend(scan_single_file_diffusion_models(model_dir))

    cache_dir = Path(HF_HUB_CACHE)
    seen_names = set()
    # Hub cache layout: <cache>/models--<org>--<name>/snapshots/<revision>/...
    # Match model_index.json only at a snapshot root, so that
    # parent.parent.parent is always the repo folder. This also avoids
    # recursively walking the whole cache (blobs included) on every scan.
    for index_file in cache_dir.glob("*/snapshots/*/model_index.json"):
        name = folder_name_to_show_name(index_file.parent.parent.parent.name)
        if name in seen_names:
            # Another snapshot of an already-listed repo.
            continue
        try:
            with open(index_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers JSONDecodeError and UnicodeDecodeError.
            logger.warning(f"Skipping unreadable {index_file}: {e}")
            continue
        class_name = data.get("_class_name") if isinstance(data, dict) else None
        model_type = _diffusers_model_type(name, class_name)
        if model_type is None:
            continue
        seen_names.add(name)
        available_models.append(
            ModelInfo(
                name=name,
                path=name,
                model_type=model_type,
            )
        )
    return available_models