from_single_file load_safety_checker work

This commit is contained in:
Qing 2024-01-10 21:22:59 +08:00
parent 772ef65f7b
commit cbcdf3b9a2
3 changed files with 16 additions and 5 deletions

View File

@ -48,7 +48,10 @@ class ControlNet(DiffusionInpaintModel):
self.controlnet_method = controlnet_method self.controlnet_method = controlnet_method
model_kwargs = {**kwargs.get("pipe_components", {})} model_kwargs = {**kwargs.get("pipe_components", {})}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
"cpu_offload", False
)
if disable_nsfw_checker:
logger.info("Disable Stable Diffusion Model NSFW checker") logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update( model_kwargs.update(
dict( dict(
@ -87,7 +90,10 @@ class ControlNet(DiffusionInpaintModel):
model_kwargs["num_in_channels"] = 9 model_kwargs["num_in_channels"] = 9
self.model = PipeClass.from_single_file( self.model = PipeClass.from_single_file(
model_info.path, controlnet=controlnet, **model_kwargs model_info.path,
controlnet=controlnet,
load_safety_checker=not disable_nsfw_checker,
**model_kwargs,
).to(torch_dtype) ).to(torch_dtype)
else: else:
self.model = handle_from_pretrained_exceptions( self.model = handle_from_pretrained_exceptions(
@ -117,7 +123,7 @@ class ControlNet(DiffusionInpaintModel):
def switch_controlnet_method(self, new_method: str): def switch_controlnet_method(self, new_method: str):
self.controlnet_method = new_method self.controlnet_method = new_method
controlnet = ControlNetModel.from_pretrained( controlnet = ControlNetModel.from_pretrained(
new_method, torch_dtype=self.torch_dtype, resume_download=True new_method, resume_download=True
).to(self.model.device) ).to(self.model.device)
self.model.controlnet = controlnet self.model.controlnet = controlnet

View File

@ -20,7 +20,8 @@ class SD(DiffusionInpaintModel):
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False)) use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
model_kwargs = {**kwargs.get("pipe_components", {})} model_kwargs = {**kwargs.get("pipe_components", {})}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False)
if disable_nsfw_checker:
logger.info("Disable Stable Diffusion Model NSFW checker") logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update( model_kwargs.update(
dict( dict(
@ -37,7 +38,10 @@ class SD(DiffusionInpaintModel):
model_kwargs["num_in_channels"] = 9 model_kwargs["num_in_channels"] = 9
self.model = StableDiffusionInpaintPipeline.from_single_file( self.model = StableDiffusionInpaintPipeline.from_single_file(
self.model_id_or_path, dtype=torch_dtype, **model_kwargs self.model_id_or_path,
dtype=torch_dtype,
load_safety_checker=not disable_nsfw_checker,
**model_kwargs,
) )
else: else:
self.model = handle_from_pretrained_exceptions( self.model = handle_from_pretrained_exceptions(

View File

@ -35,6 +35,7 @@ class SDXL(DiffusionInpaintModel):
self.model_id_or_path, self.model_id_or_path,
dtype=torch_dtype, dtype=torch_dtype,
num_in_channels=num_in_channels, num_in_channels=num_in_channels,
load_safety_checker=False
) )
else: else:
model_kwargs = {**kwargs.get("pipe_components", {})} model_kwargs = {**kwargs.get("pipe_components", {})}