From ef1179a8587096a2c9c8faf8c41024ed5fb6e6b1 Mon Sep 17 00:00:00 2001 From: Qing Date: Thu, 16 Nov 2023 21:45:55 +0800 Subject: [PATCH] add get_diffusers_models --- lama_cleaner/diffusers_utils.py | 29 +++++++++++++++++++++-------- lama_cleaner/server.py | 9 +++++++++ 2 files changed, 30 insertions(+), 8 deletions(-) diff --git a/lama_cleaner/diffusers_utils.py b/lama_cleaner/diffusers_utils.py index d53fa7f..d67c80c 100644 --- a/lama_cleaner/diffusers_utils.py +++ b/lama_cleaner/diffusers_utils.py @@ -7,22 +7,35 @@ def folder_name_to_show_name(name: str) -> str: return name.replace("models--", "").replace("--", "/") -def _scan_models(cache_dir, class_name: str) -> List[str]: +def _scan_models(cache_dir, class_name: List[str]) -> List[str]: cache_dir = Path(cache_dir) res = [] for it in cache_dir.glob("**/*/model_index.json"): with open(it, "r", encoding="utf-8") as f: data = json.load(f) - if data["_class_name"] == class_name: + if data["_class_name"] in class_name: name = folder_name_to_show_name(it.parent.parent.parent.name) if name not in res: res.append(name) return res -def scan_models(cache_dir) -> List[str]: - return _scan_models(cache_dir, "StableDiffusionPipeline") - - -def scan_inpainting_models(cache_dir) -> List[str]: - return _scan_models(cache_dir, "StableDiffusionInpaintPipeline") +def scan_models(cache_dir) -> Dict[str, List[str]]: + return { + "sd": _scan_models(cache_dir, ["StableDiffusionPipeline"]), + "sd_inpaint": _scan_models( + cache_dir, + [ + "StableDiffusionInpaintPipeline", + "StableDiffusionXLInpaintPipeline", + "KandinskyV22InpaintPipeline", + ], + ), + "other": _scan_models( + cache_dir, + [ + "StableDiffusionInstructPix2PixPipeline", + "PaintByExamplePipeline", + ], + ), + } diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py index cb762bf..3088e7f 100644 --- a/lama_cleaner/server.py +++ b/lama_cleaner/server.py @@ -2,6 +2,8 @@ import os import hashlib +from lama_cleaner.diffusers_utils import scan_models + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" import imghdr @@ -432,6 +434,13 @@ def get_server_config(): }, 200 +@app.route("/sd_models", methods=["GET"]) +def get_diffusers_models(): + from diffusers.utils import DIFFUSERS_CACHE + + return scan_models(DIFFUSERS_CACHE) + + @app.route("/model") def current_model(): return model.name, 200