diff --git a/iopaint/model/controlnet.py b/iopaint/model/controlnet.py
index d3894a6..8630c78 100644
--- a/iopaint/model/controlnet.py
+++ b/iopaint/model/controlnet.py
@@ -47,7 +47,7 @@ class ControlNet(DiffusionInpaintModel):
         self.model_info = model_info
         self.controlnet_method = controlnet_method
 
-        model_kwargs = {}
+        model_kwargs = {**kwargs.get("pipe_components", {})}
         if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
             logger.info("Disable Stable Diffusion Model NSFW checker")
             model_kwargs.update(
diff --git a/iopaint/model/helper/cpu_text_encoder.py b/iopaint/model/helper/cpu_text_encoder.py
index 532f6ff..ed01630 100644
--- a/iopaint/model/helper/cpu_text_encoder.py
+++ b/iopaint/model/helper/cpu_text_encoder.py
@@ -1,10 +1,12 @@
 import torch
+from transformers import PreTrainedModel
+
 from ..utils import torch_gc
 
 
-class CPUTextEncoderWrapper(torch.nn.Module):
+class CPUTextEncoderWrapper(PreTrainedModel):
     def __init__(self, text_encoder, torch_dtype):
-        super().__init__()
+        super().__init__(text_encoder.config)
         self.config = text_encoder.config
         self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
         self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
diff --git a/iopaint/model/sd.py b/iopaint/model/sd.py
index eb1a462..ace3932 100644
--- a/iopaint/model/sd.py
+++ b/iopaint/model/sd.py
@@ -19,7 +19,7 @@ class SD(DiffusionInpaintModel):
 
         use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
 
-        model_kwargs = {}
+        model_kwargs = {**kwargs.get("pipe_components", {})}
         if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
             logger.info("Disable Stable Diffusion Model NSFW checker")
             model_kwargs.update(
diff --git a/iopaint/model/sdxl.py b/iopaint/model/sdxl.py
index f54a721..f8795d7 100644
--- a/iopaint/model/sdxl.py
+++ b/iopaint/model/sdxl.py
@@ -36,15 +36,18 @@ class SDXL(DiffusionInpaintModel):
                 num_in_channels=num_in_channels,
             )
         else:
-            vae = AutoencoderKL.from_pretrained(
-                "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
-            )
+            model_kwargs = {**kwargs.get("pipe_components", {})}
+            if 'vae' not in model_kwargs:
+                vae = AutoencoderKL.from_pretrained(
+                    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch_dtype
+                )
+                model_kwargs['vae'] = vae
             self.model = handle_from_pretrained_exceptions(
                 StableDiffusionXLInpaintPipeline.from_pretrained,
                 pretrained_model_name_or_path=self.model_id_or_path,
                 torch_dtype=torch_dtype,
-                vae=vae,
                 variant="fp16",
+                **model_kwargs
             )
         enable_low_mem(self.model, kwargs.get("low_mem", False))
 
diff --git a/iopaint/model/utils.py b/iopaint/model/utils.py
index 80ecbf9..c9b2b93 100644
--- a/iopaint/model/utils.py
+++ b/iopaint/model/utils.py
@@ -978,6 +978,7 @@ def handle_from_pretrained_exceptions(func, **kwargs):
         if "You are trying to load the model files of the `variant=fp16`" in str(e):
             logger.info("variant=fp16 not found, try revision=fp16")
             return func(**{**kwargs, "variant": None, "revision": "fp16"})
+        raise e
     except OSError as e:
         previous_traceback = traceback.format_exc()
         if "RevisionNotFoundError: 404 Client Error." in previous_traceback:
diff --git a/iopaint/model_manager.py b/iopaint/model_manager.py
index 9be5c5d..623fb84 100644
--- a/iopaint/model_manager.py
+++ b/iopaint/model_manager.py
@@ -141,8 +141,19 @@ class ModelManager:
         self.enable_controlnet = config.enable_controlnet
         self.controlnet_method = config.controlnet_method
 
+        pipe_components = {
+            "vae": self.model.model.vae,
+            "text_encoder": self.model.model.text_encoder,
+            "unet": self.model.model.unet,
+        }
+        if hasattr(self.model.model, "text_encoder_2"):
+            pipe_components["text_encoder_2"] = self.model.model.text_encoder_2
+
         self.model = self.init_model(
-            self.name, switch_mps_device(self.name, self.device), **self.kwargs
+            self.name,
+            switch_mps_device(self.name, self.device),
+            pipe_components=pipe_components,
+            **self.kwargs,
         )
         if not config.enable_controlnet:
             logger.info(f"Disable controlnet")