Qing 2024-04-12 13:09:37 +08:00
parent 21f17c8e2a
commit bfad842a0e
4 changed files with 13 additions and 9 deletions

View File

@ -25,11 +25,11 @@ def cli_download_model(model: str):
if model in models and models[model].is_erase_model: if model in models and models[model].is_erase_model:
logger.info(f"Downloading {model}...") logger.info(f"Downloading {model}...")
models[model].download() models[model].download()
logger.info(f"Done.") logger.info("Done.")
elif model == ANYTEXT_NAME: elif model == ANYTEXT_NAME:
logger.info(f"Downloading {model}...") logger.info(f"Downloading {model}...")
models[model].download() models[model].download()
logger.info(f"Done.") logger.info("Done.")
else: else:
logger.info(f"Downloading model from Huggingface: {model}") logger.info(f"Downloading model from Huggingface: {model}")
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
@ -60,7 +60,7 @@ def get_sd_model_type(model_abs_path: str) -> ModelType:
model_abs_path, model_abs_path,
load_safety_checker=False, load_safety_checker=False,
num_in_channels=9, num_in_channels=9,
config_files=get_config_files(), original_config_file=get_config_files()['v1']
) )
model_type = ModelType.DIFFUSERS_SD_INPAINT model_type = ModelType.DIFFUSERS_SD_INPAINT
except ValueError as e: except ValueError as e:
@ -84,7 +84,7 @@ def get_sdxl_model_type(model_abs_path: str) -> ModelType:
model_abs_path, model_abs_path,
load_safety_checker=False, load_safety_checker=False,
num_in_channels=9, num_in_channels=9,
config_files=get_config_files(), original_config_file=get_config_files()['xl'],
) )
if model.unet.config.in_channels == 9: if model.unet.config.in_channels == 9:
# https://github.com/huggingface/diffusers/issues/6610 # https://github.com/huggingface/diffusers/issues/6610
@ -113,7 +113,7 @@ def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
pass pass
res = [] res = []
for it in stable_diffusion_dir.glob(f"*.*"): for it in stable_diffusion_dir.glob("*.*"):
if it.suffix not in [".safetensors", ".ckpt"]: if it.suffix not in [".safetensors", ".ckpt"]:
continue continue
model_abs_path = str(it.absolute()) model_abs_path = str(it.absolute())
@ -144,7 +144,7 @@ def scan_single_file_diffusion_models(cache_dir) -> List[ModelInfo]:
except: except:
pass pass
for it in stable_diffusion_xl_dir.glob(f"*.*"): for it in stable_diffusion_xl_dir.glob("*.*"):
if it.suffix not in [".safetensors", ".ckpt"]: if it.suffix not in [".safetensors", ".ckpt"]:
continue continue
model_abs_path = str(it.absolute()) model_abs_path = str(it.absolute())

View File

@ -71,6 +71,7 @@ class ControlNet(DiffusionInpaintModel):
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
self.torch_dtype = torch_dtype self.torch_dtype = torch_dtype
original_config_file_name = "v1"
if model_info.model_type in [ if model_info.model_type in [
ModelType.DIFFUSERS_SD, ModelType.DIFFUSERS_SD,
ModelType.DIFFUSERS_SD_INPAINT, ModelType.DIFFUSERS_SD_INPAINT,
@ -78,6 +79,8 @@ class ControlNet(DiffusionInpaintModel):
from diffusers import ( from diffusers import (
StableDiffusionControlNetInpaintPipeline as PipeClass, StableDiffusionControlNetInpaintPipeline as PipeClass,
) )
original_config_file_name = "v1"
elif model_info.model_type in [ elif model_info.model_type in [
ModelType.DIFFUSERS_SDXL, ModelType.DIFFUSERS_SDXL,
ModelType.DIFFUSERS_SDXL_INPAINT, ModelType.DIFFUSERS_SDXL_INPAINT,
@ -85,6 +88,7 @@ class ControlNet(DiffusionInpaintModel):
from diffusers import ( from diffusers import (
StableDiffusionXLControlNetInpaintPipeline as PipeClass, StableDiffusionXLControlNetInpaintPipeline as PipeClass,
) )
original_config_file_name = "xl"
controlnet = ControlNetModel.from_pretrained( controlnet = ControlNetModel.from_pretrained(
pretrained_model_name_or_path=controlnet_method, pretrained_model_name_or_path=controlnet_method,
@ -103,7 +107,7 @@ class ControlNet(DiffusionInpaintModel):
controlnet=controlnet, controlnet=controlnet,
load_safety_checker=not disable_nsfw_checker, load_safety_checker=not disable_nsfw_checker,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
config_files=get_config_files(), original_config_file=get_config_files()[original_config_file_name],
**model_kwargs, **model_kwargs,
) )
else: else:

View File

@ -52,7 +52,7 @@ class SD(DiffusionInpaintModel):
self.model_id_or_path, self.model_id_or_path,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
load_safety_checker=not disable_nsfw_checker, load_safety_checker=not disable_nsfw_checker,
config_files=get_config_files(), original_config_file=get_config_files()['v1'],
**model_kwargs, **model_kwargs,
) )
else: else:

View File

@ -42,7 +42,7 @@ class SDXL(DiffusionInpaintModel):
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
num_in_channels=num_in_channels, num_in_channels=num_in_channels,
load_safety_checker=False, load_safety_checker=False,
config_files=get_config_files() original_config_file=get_config_files()['xl'],
) )
else: else:
model_kwargs = { model_kwargs = {