support scan converted diffusers

This commit is contained in:
Qing 2024-02-06 16:32:42 +08:00
parent 92bbd82a53
commit 68f54444e0

View File

@ -185,13 +185,10 @@ def scan_inpaint_models(model_dir: Path) -> List[ModelInfo]:
return res return res
def scan_models() -> List[ModelInfo]: def scan_diffusers_models() -> List[ModelInfo]:
from huggingface_hub.constants import HF_HUB_CACHE from huggingface_hub.constants import HF_HUB_CACHE
model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR)
available_models = [] available_models = []
available_models.extend(scan_inpaint_models(model_dir))
available_models.extend(scan_single_file_diffusion_models(model_dir))
cache_dir = Path(HF_HUB_CACHE) cache_dir = Path(HF_HUB_CACHE)
# logger.info(f"Scanning diffusers models in {cache_dir}") # logger.info(f"Scanning diffusers models in {cache_dir}")
diffusers_model_names = [] diffusers_model_names = []
@ -234,5 +231,64 @@ def scan_models() -> List[ModelInfo]:
model_type=model_type, model_type=model_type,
) )
) )
return available_models
def _scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
cache_dir = Path(cache_dir)
available_models = []
diffusers_model_names = []
for it in cache_dir.glob("**/*/model_index.json"):
with open(it, "r", encoding="utf-8") as f:
try:
data = json.load(f)
except:
logger.error(
f"Failed to load {it}, please try revert from original model or fix model_index.json by hand."
)
continue
_class_name = data["_class_name"]
name = folder_name_to_show_name(it.parent.name)
if name in diffusers_model_names:
continue
elif _class_name == DIFFUSERS_SD_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SD
elif _class_name == DIFFUSERS_SD_INPAINT_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SD_INPAINT
elif _class_name == DIFFUSERS_SDXL_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SDXL
elif _class_name == DIFFUSERS_SDXL_INPAINT_CLASS_NAME:
model_type = ModelType.DIFFUSERS_SDXL_INPAINT
else:
continue
diffusers_model_names.append(name)
available_models.append(
ModelInfo(
name=name,
path=str(it.parent.absolute()),
model_type=model_type,
)
)
return available_models
def scan_converted_diffusers_models(cache_dir) -> List[ModelInfo]:
cache_dir = Path(cache_dir)
available_models = []
stable_diffusion_dir = cache_dir / "stable_diffusion"
stable_diffusion_xl_dir = cache_dir / "stable_diffusion_xl"
available_models.extend(_scan_converted_diffusers_models(stable_diffusion_dir))
available_models.extend(_scan_converted_diffusers_models(stable_diffusion_xl_dir))
return available_models
def scan_models() -> List[ModelInfo]:
model_dir = os.getenv("XDG_CACHE_HOME", DEFAULT_MODEL_DIR)
available_models = []
available_models.extend(scan_inpaint_models(model_dir))
available_models.extend(scan_single_file_diffusion_models(model_dir))
available_models.extend(scan_diffusers_models())
available_models.extend(scan_converted_diffusers_models(model_dir))
return available_models return available_models