add get_diffusers_models
This commit is contained in:
parent
1d145d1cd6
commit
ef1179a858
@ -7,22 +7,35 @@ def folder_name_to_show_name(name: str) -> str:
|
|||||||
return name.replace("models--", "").replace("--", "/")
|
return name.replace("models--", "").replace("--", "/")
|
||||||
|
|
||||||
|
|
||||||
def _scan_models(cache_dir, class_name: str) -> List[str]:
|
def _scan_models(cache_dir, class_name: List[str]) -> List[str]:
|
||||||
cache_dir = Path(cache_dir)
|
cache_dir = Path(cache_dir)
|
||||||
res = []
|
res = []
|
||||||
for it in cache_dir.glob("**/*/model_index.json"):
|
for it in cache_dir.glob("**/*/model_index.json"):
|
||||||
with open(it, "r", encoding="utf-8") as f:
|
with open(it, "r", encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
if data["_class_name"] == class_name:
|
if data["_class_name"] in class_name:
|
||||||
name = folder_name_to_show_name(it.parent.parent.parent.name)
|
name = folder_name_to_show_name(it.parent.parent.parent.name)
|
||||||
if name not in res:
|
if name not in res:
|
||||||
res.append(name)
|
res.append(name)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
def scan_models(cache_dir) -> List[str]:
|
def scan_models(cache_dir) -> Dict[str, List[str]]:
|
||||||
return _scan_models(cache_dir, "StableDiffusionPipeline")
|
return {
|
||||||
|
"sd": _scan_models(cache_dir, ["StableDiffusionPipeline"]),
|
||||||
|
"sd_inpaint": _scan_models(
|
||||||
def scan_inpainting_models(cache_dir) -> List[str]:
|
cache_dir,
|
||||||
return _scan_models(cache_dir, "StableDiffusionInpaintPipeline")
|
[
|
||||||
|
"StableDiffusionInpaintPipeline",
|
||||||
|
"StableDiffusionXLInpaintPipeline",
|
||||||
|
"KandinskyV22InpaintPipeline",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
"other": _scan_models(
|
||||||
|
cache_dir,
|
||||||
|
[
|
||||||
|
"StableDiffusionInstructPix2PixPipeline",
|
||||||
|
"PaintByExamplePipeline",
|
||||||
|
],
|
||||||
|
),
|
||||||
|
}
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
import os
|
import os
|
||||||
import hashlib
|
import hashlib
|
||||||
|
|
||||||
|
from lama_cleaner.diffusers_utils import scan_models
|
||||||
|
|
||||||
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
||||||
|
|
||||||
import imghdr
|
import imghdr
|
||||||
@ -432,6 +434,13 @@ def get_server_config():
|
|||||||
}, 200
|
}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@app.route("/sd_models", methods=["GET"])
|
||||||
|
def get_diffusers_models():
|
||||||
|
from diffusers.utils import DIFFUSERS_CACHE
|
||||||
|
|
||||||
|
return scan_models(DIFFUSERS_CACHE)
|
||||||
|
|
||||||
|
|
||||||
@app.route("/model")
|
@app.route("/model")
|
||||||
def current_model():
|
def current_model():
|
||||||
return model.name, 200
|
return model.name, 200
|
||||||
|
Loading…
Reference in New Issue
Block a user