diff --git a/iopaint/model/controlnet.py b/iopaint/model/controlnet.py index 8630c78..3b02b47 100644 --- a/iopaint/model/controlnet.py +++ b/iopaint/model/controlnet.py @@ -48,7 +48,10 @@ class ControlNet(DiffusionInpaintModel): self.controlnet_method = controlnet_method model_kwargs = {**kwargs.get("pipe_components", {})} - if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): + disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get( + "cpu_offload", False + ) + if disable_nsfw_checker: logger.info("Disable Stable Diffusion Model NSFW checker") model_kwargs.update( dict( @@ -87,7 +90,10 @@ class ControlNet(DiffusionInpaintModel): model_kwargs["num_in_channels"] = 9 self.model = PipeClass.from_single_file( - model_info.path, controlnet=controlnet, **model_kwargs + model_info.path, + controlnet=controlnet, + load_safety_checker=not disable_nsfw_checker, + **model_kwargs, ).to(torch_dtype) else: self.model = handle_from_pretrained_exceptions( @@ -117,7 +123,7 @@ class ControlNet(DiffusionInpaintModel): def switch_controlnet_method(self, new_method: str): self.controlnet_method = new_method controlnet = ControlNetModel.from_pretrained( - new_method, torch_dtype=self.torch_dtype, resume_download=True + new_method, resume_download=True ).to(self.model.device) self.model.controlnet = controlnet diff --git a/iopaint/model/sd.py b/iopaint/model/sd.py index ace3932..f9d315c 100644 --- a/iopaint/model/sd.py +++ b/iopaint/model/sd.py @@ -20,7 +20,8 @@ class SD(DiffusionInpaintModel): use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) model_kwargs = {**kwargs.get("pipe_components", {})} - if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): + disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False) + if disable_nsfw_checker: logger.info("Disable Stable Diffusion Model NSFW checker") model_kwargs.update( dict( @@ -37,7 +38,10 @@ class SD(DiffusionInpaintModel): model_kwargs["num_in_channels"] = 9 self.model = StableDiffusionInpaintPipeline.from_single_file( - self.model_id_or_path, dtype=torch_dtype, **model_kwargs + self.model_id_or_path, + dtype=torch_dtype, + load_safety_checker=not disable_nsfw_checker, + **model_kwargs, ) else: self.model = handle_from_pretrained_exceptions( diff --git a/iopaint/model/sdxl.py b/iopaint/model/sdxl.py index 277dec0..521c472 100644 --- a/iopaint/model/sdxl.py +++ b/iopaint/model/sdxl.py @@ -35,6 +35,7 @@ class SDXL(DiffusionInpaintModel): self.model_id_or_path, dtype=torch_dtype, num_in_channels=num_in_channels, + load_safety_checker=False ) else: model_kwargs = {**kwargs.get("pipe_components", {})}