add get_diffusers_models
This commit is contained in:
parent
1d145d1cd6
commit
ef1179a858
@ -7,22 +7,35 @@ def folder_name_to_show_name(name: str) -> str:
|
||||
return name.replace("models--", "").replace("--", "/")
|
||||
|
||||
|
||||
def _scan_models(cache_dir, class_name: str) -> List[str]:
|
||||
def _scan_models(cache_dir, class_name: List[str]) -> List[str]:
|
||||
cache_dir = Path(cache_dir)
|
||||
res = []
|
||||
for it in cache_dir.glob("**/*/model_index.json"):
|
||||
with open(it, "r", encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
if data["_class_name"] == class_name:
|
||||
if data["_class_name"] in class_name:
|
||||
name = folder_name_to_show_name(it.parent.parent.parent.name)
|
||||
if name not in res:
|
||||
res.append(name)
|
||||
return res
|
||||
|
||||
|
||||
def scan_models(cache_dir) -> List[str]:
|
||||
return _scan_models(cache_dir, "StableDiffusionPipeline")
|
||||
|
||||
|
||||
def scan_inpainting_models(cache_dir) -> List[str]:
|
||||
return _scan_models(cache_dir, "StableDiffusionInpaintPipeline")
|
||||
def scan_models(cache_dir) -> Dict[str, List[str]]:
|
||||
return {
|
||||
"sd": _scan_models(cache_dir, ["StableDiffusionPipeline"]),
|
||||
"sd_inpaint": _scan_models(
|
||||
cache_dir,
|
||||
[
|
||||
"StableDiffusionInpaintPipeline",
|
||||
"StableDiffusionXLInpaintPipeline",
|
||||
"KandinskyV22InpaintPipeline",
|
||||
],
|
||||
),
|
||||
"other": _scan_models(
|
||||
cache_dir,
|
||||
[
|
||||
"StableDiffusionInstructPix2PixPipeline",
|
||||
"PaintByExamplePipeline",
|
||||
],
|
||||
),
|
||||
}
|
||||
|
@ -2,6 +2,8 @@
|
||||
import os
|
||||
import hashlib
|
||||
|
||||
from lama_cleaner.diffusers_utils import scan_models
|
||||
|
||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||
|
||||
import imghdr
|
||||
@ -432,6 +434,13 @@ def get_server_config():
|
||||
}, 200
|
||||
|
||||
|
||||
@app.route("/sd_models", methods=["GET"])
|
||||
def get_diffusers_models():
|
||||
from diffusers.utils import DIFFUSERS_CACHE
|
||||
|
||||
return scan_models(DIFFUSERS_CACHE)
|
||||
|
||||
|
||||
@app.route("/model")
|
||||
def current_model():
|
||||
return model.name, 200
|
||||
|
Loading…
Reference in New Issue
Block a user