diffusers removed config_files api: https://github.com/huggingface/diffusers/issues/6819#issuecomment-1928396112
This commit is contained in:
parent
f71e9cfb26
commit
35f12d5b9b
@ -25,11 +25,11 @@ def cli_download_model(model: str):
|
||||
if model in models and models[model].is_erase_model:
|
||||
logger.info(f"Downloading {model}...")
|
||||
models[model].download()
|
||||
logger.info(f"Done.")
|
||||
logger.info("Done.")
|
||||
elif model == ANYTEXT_NAME:
|
||||
logger.info(f"Downloading {model}...")
|
||||
models[model].download()
|
||||
logger.info(f"Done.")
|
||||
logger.info("Done.")
|
||||
else:
|
||||
logger.info(f"Downloading model from Huggingface: {model}")
|
||||
from diffusers import DiffusionPipeline
|
||||
@ -60,7 +60,7 @@ def get_sd_model_type(model_abs_path: str) -> ModelType:
|
||||
model_abs_path,
|
||||
load_safety_checker=False,
|
||||
num_in_channels=9,
|
||||
config_files=get_config_files(),
|
||||
original_config_file=get_config_files()['v1']
|
||||
)
|
||||
model_type = ModelType.DIFFUSERS_SD_INPAINT
|
||||
except ValueError as e:
|
||||
@ -84,7 +84,7 @@ def get_sdxl_model_type(model_abs_path: str) -> ModelType:
|
||||
model_abs_path,
|
||||
load_safety_checker=False,
|
||||
num_in_channels=9,
|
||||
config_files=get_config_files(),
|
||||
original_config_file=get_config_files()['xl'],
|
||||
)
|
||||
if model.unet.config.in_channels == 9:
|
||||
# https://github.com/huggingface/diffusers/issues/6610
|
||||
@ -113,7 +113,7 @@ def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
|
||||
pass
|
||||
|
||||
res = []
|
||||
for it in stable_diffusion_dir.glob(f"*.*"):
|
||||
for it in stable_diffusion_dir.glob("*.*"):
|
||||
if it.suffix not in [".safetensors", ".ckpt"]:
|
||||
continue
|
||||
model_abs_path = str(it.absolute())
|
||||
@ -144,7 +144,7 @@ def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
|
||||
except:
|
||||
pass
|
||||
|
||||
for it in stable_diffusion_xl_dir.glob(f"*.*"):
|
||||
for it in stable_diffusion_xl_dir.glob("*.*"):
|
||||
if it.suffix not in [".safetensors", ".ckpt"]:
|
||||
continue
|
||||
model_abs_path = str(it.absolute())
|
||||
|
@ -71,6 +71,7 @@ class ControlNet(DiffusionInpaintModel):
|
||||
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||
self.torch_dtype = torch_dtype
|
||||
|
||||
original_config_file_name = "v1"
|
||||
if model_info.model_type in [
|
||||
ModelType.DIFFUSERS_SD,
|
||||
ModelType.DIFFUSERS_SD_INPAINT,
|
||||
@ -78,6 +79,8 @@ class ControlNet(DiffusionInpaintModel):
|
||||
from diffusers import (
|
||||
StableDiffusionControlNetInpaintPipeline as PipeClass,
|
||||
)
|
||||
original_config_file_name = "v1"
|
||||
|
||||
elif model_info.model_type in [
|
||||
ModelType.DIFFUSERS_SDXL,
|
||||
ModelType.DIFFUSERS_SDXL_INPAINT,
|
||||
@ -85,6 +88,7 @@ class ControlNet(DiffusionInpaintModel):
|
||||
from diffusers import (
|
||||
StableDiffusionXLControlNetInpaintPipeline as PipeClass,
|
||||
)
|
||||
original_config_file_name = "xl"
|
||||
|
||||
controlnet = ControlNetModel.from_pretrained(
|
||||
pretrained_model_name_or_path=controlnet_method,
|
||||
@ -103,7 +107,7 @@ class ControlNet(DiffusionInpaintModel):
|
||||
controlnet=controlnet,
|
||||
load_safety_checker=not disable_nsfw_checker,
|
||||
torch_dtype=torch_dtype,
|
||||
config_files=get_config_files(),
|
||||
original_config_file=get_config_files()[original_config_file_name],
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
|
@ -52,7 +52,7 @@ class SD(DiffusionInpaintModel):
|
||||
self.model_id_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
load_safety_checker=not disable_nsfw_checker,
|
||||
config_files=get_config_files(),
|
||||
original_config_file=get_config_files()['v1'],
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
|
@ -42,7 +42,7 @@ class SDXL(DiffusionInpaintModel):
|
||||
torch_dtype=torch_dtype,
|
||||
num_in_channels=num_in_channels,
|
||||
load_safety_checker=False,
|
||||
config_files=get_config_files()
|
||||
original_config_file=get_config_files()['xl'],
|
||||
)
|
||||
else:
|
||||
model_kwargs = {
|
||||
|
Loading…
Reference in New Issue
Block a user