add get_diffusers_models

This commit is contained in:
Qing 2023-11-16 21:45:55 +08:00
parent 1d145d1cd6
commit ef1179a858
2 changed files with 30 additions and 8 deletions

View File

@ -7,22 +7,35 @@ def folder_name_to_show_name(name: str) -> str:
return name.replace("models--", "").replace("--", "/")
def _scan_models(cache_dir, class_name: str) -> List[str]:
def _scan_models(cache_dir, class_name: List[str]) -> List[str]:
cache_dir = Path(cache_dir)
res = []
for it in cache_dir.glob("**/*/model_index.json"):
with open(it, "r", encoding="utf-8") as f:
data = json.load(f)
if data["_class_name"] == class_name:
if data["_class_name"] in class_name:
name = folder_name_to_show_name(it.parent.parent.parent.name)
if name not in res:
res.append(name)
return res
def scan_models(cache_dir) -> List[str]:
return _scan_models(cache_dir, "StableDiffusionPipeline")
def scan_inpainting_models(cache_dir) -> List[str]:
return _scan_models(cache_dir, "StableDiffusionInpaintPipeline")
def scan_models(cache_dir) -> Dict[str, List[str]]:
return {
"sd": _scan_models(cache_dir, ["StableDiffusionPipeline"]),
"sd_inpaint": _scan_models(
cache_dir,
[
"StableDiffusionInpaintPipeline",
"StableDiffusionXLInpaintPipeline",
"KandinskyV22InpaintPipeline",
],
),
"other": _scan_models(
cache_dir,
[
"StableDiffusionInstructPix2PixPipeline",
"PaintByExamplePipeline",
],
),
}

View File

@ -2,6 +2,8 @@
import os
import hashlib
from lama_cleaner.diffusers_utils import scan_models
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import imghdr
@ -432,6 +434,13 @@ def get_server_config():
}, 200
@app.route("/sd_models", methods=["GET"])
def get_diffusers_models():
from diffusers.utils import DIFFUSERS_CACHE
return scan_models(DIFFUSERS_CACHE)
@app.route("/model")
def current_model():
return model.name, 200