2023-11-16 14:12:06 +01:00
|
|
|
import json
|
|
|
|
from pathlib import Path
|
|
|
|
from typing import Dict, List
|
|
|
|
|
|
|
|
|
|
|
|
def folder_name_to_show_name(name: str) -> str:
    """Turn a cache folder name into a human-readable model id.

    Every occurrence of ``"models--"`` is dropped, then each remaining
    ``"--"`` separator becomes ``"/"``, e.g.
    ``"models--org--repo"`` -> ``"org/repo"``.
    """
    stripped = name.replace("models--", "")
    # split/join on "--" is equivalent to replacing every "--" with "/".
    return "/".join(stripped.split("--"))
|
|
|
|
|
|
|
|
|
2023-11-16 14:45:55 +01:00
|
|
|
def _scan_models(cache_dir, class_name: List[str]) -> List[str]:
    """Find cached models whose pipeline class is one of ``class_name``.

    Recursively globs ``cache_dir`` for ``model_index.json`` files and reads
    each one's ``"_class_name"`` entry. When that value is in ``class_name``,
    the model's display name (see ``folder_name_to_show_name``) is recorded.

    Args:
        cache_dir: Directory to scan (str or ``Path``).
        class_name: Pipeline class names to accept.

    Returns:
        Display names in first-seen order, without duplicates.
    """
    cache_dir = Path(cache_dir)
    res: List[str] = []
    seen = set()  # O(1) duplicate check; ``res`` keeps first-seen order
    for it in cache_dir.glob("**/*/model_index.json"):
        try:
            with open(it, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # A broken snapshot symlink (e.g. from an interrupted download)
            # or an undecodable/corrupt file should not abort the whole
            # scan; ValueError covers JSONDecodeError and UnicodeDecodeError.
            continue
        if not isinstance(data, dict) or data.get("_class_name") not in class_name:
            continue
        # NOTE(review): assumes the Hugging Face hub cache layout
        # <cache>/models--org--repo/snapshots/<revision>/model_index.json,
        # so the great-grandparent folder names the model — confirm.
        name = folder_name_to_show_name(it.parent.parent.parent.name)
        if name not in seen:
            seen.add(name)
            res.append(name)
    return res
|
|
|
|
|
|
|
|
|
2023-11-16 14:45:55 +01:00
|
|
|
def scan_models(cache_dir) -> Dict[str, List[str]]:
    """Group cached models under ``cache_dir`` into display categories.

    Returns a dict with keys ``"sd"``, ``"sd_inpaint"`` and ``"other"``.
    Each key maps to the names that ``_scan_models`` finds for that
    category's pipeline class names.
    """
    category_classes = {
        "sd": ["StableDiffusionPipeline"],
        "sd_inpaint": [
            "StableDiffusionInpaintPipeline",
            "StableDiffusionXLInpaintPipeline",
            "KandinskyV22InpaintPipeline",
        ],
        "other": [
            "StableDiffusionInstructPix2PixPipeline",
            "PaintByExamplePipeline",
        ],
    }
    # Dict insertion order keeps the categories (and the scans) in the
    # same sequence as listed above.
    return {
        category: _scan_models(cache_dir, classes)
        for category, classes in category_classes.items()
    }
|