add get_diffusers_models

This commit is contained in:
Qing 2023-11-16 21:45:55 +08:00
parent 1d145d1cd6
commit ef1179a858
2 changed files with 30 additions and 8 deletions

View File

@ -7,22 +7,35 @@ def folder_name_to_show_name(name: str) -> str:
return name.replace("models--", "").replace("--", "/") return name.replace("models--", "").replace("--", "/")
def _scan_models(cache_dir, class_name: str) -> List[str]: def _scan_models(cache_dir, class_name: List[str]) -> List[str]:
cache_dir = Path(cache_dir) cache_dir = Path(cache_dir)
res = [] res = []
for it in cache_dir.glob("**/*/model_index.json"): for it in cache_dir.glob("**/*/model_index.json"):
with open(it, "r", encoding="utf-8") as f: with open(it, "r", encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
if data["_class_name"] == class_name: if data["_class_name"] in class_name:
name = folder_name_to_show_name(it.parent.parent.parent.name) name = folder_name_to_show_name(it.parent.parent.parent.name)
if name not in res: if name not in res:
res.append(name) res.append(name)
return res return res
def scan_models(cache_dir) -> List[str]: def scan_models(cache_dir) -> Dict[str, List[str]]:
return _scan_models(cache_dir, "StableDiffusionPipeline") return {
"sd": _scan_models(cache_dir, ["StableDiffusionPipeline"]),
"sd_inpaint": _scan_models(
def scan_inpainting_models(cache_dir) -> List[str]: cache_dir,
return _scan_models(cache_dir, "StableDiffusionInpaintPipeline") [
"StableDiffusionInpaintPipeline",
"StableDiffusionXLInpaintPipeline",
"KandinskyV22InpaintPipeline",
],
),
"other": _scan_models(
cache_dir,
[
"StableDiffusionInstructPix2PixPipeline",
"PaintByExamplePipeline",
],
),
}

View File

@ -2,6 +2,8 @@
import os import os
import hashlib import hashlib
from lama_cleaner.diffusers_utils import scan_models
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import imghdr import imghdr
@ -432,6 +434,13 @@ def get_server_config():
}, 200 }, 200
@app.route("/sd_models", methods=["GET"])
def get_diffusers_models():
from diffusers.utils import DIFFUSERS_CACHE
return scan_models(DIFFUSERS_CACHE)
@app.route("/model") @app.route("/model")
def current_model(): def current_model():
return model.name, 200 return model.name, 200