from_single_file load_safety_checker work
This commit is contained in:
parent
772ef65f7b
commit
cbcdf3b9a2
@ -48,7 +48,10 @@ class ControlNet(DiffusionInpaintModel):
|
|||||||
self.controlnet_method = controlnet_method
|
self.controlnet_method = controlnet_method
|
||||||
|
|
||||||
model_kwargs = {**kwargs.get("pipe_components", {})}
|
model_kwargs = {**kwargs.get("pipe_components", {})}
|
||||||
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get(
|
||||||
|
"cpu_offload", False
|
||||||
|
)
|
||||||
|
if disable_nsfw_checker:
|
||||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||||
model_kwargs.update(
|
model_kwargs.update(
|
||||||
dict(
|
dict(
|
||||||
@ -87,7 +90,10 @@ class ControlNet(DiffusionInpaintModel):
|
|||||||
model_kwargs["num_in_channels"] = 9
|
model_kwargs["num_in_channels"] = 9
|
||||||
|
|
||||||
self.model = PipeClass.from_single_file(
|
self.model = PipeClass.from_single_file(
|
||||||
model_info.path, controlnet=controlnet, **model_kwargs
|
model_info.path,
|
||||||
|
controlnet=controlnet,
|
||||||
|
load_safety_checker=not disable_nsfw_checker,
|
||||||
|
**model_kwargs,
|
||||||
).to(torch_dtype)
|
).to(torch_dtype)
|
||||||
else:
|
else:
|
||||||
self.model = handle_from_pretrained_exceptions(
|
self.model = handle_from_pretrained_exceptions(
|
||||||
@ -117,7 +123,7 @@ class ControlNet(DiffusionInpaintModel):
|
|||||||
def switch_controlnet_method(self, new_method: str):
|
def switch_controlnet_method(self, new_method: str):
|
||||||
self.controlnet_method = new_method
|
self.controlnet_method = new_method
|
||||||
controlnet = ControlNetModel.from_pretrained(
|
controlnet = ControlNetModel.from_pretrained(
|
||||||
new_method, torch_dtype=self.torch_dtype, resume_download=True
|
new_method, resume_download=True
|
||||||
).to(self.model.device)
|
).to(self.model.device)
|
||||||
self.model.controlnet = controlnet
|
self.model.controlnet = controlnet
|
||||||
|
|
||||||
|
@ -20,7 +20,8 @@ class SD(DiffusionInpaintModel):
|
|||||||
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
use_gpu, torch_dtype = get_torch_dtype(device, kwargs.get("no_half", False))
|
||||||
|
|
||||||
model_kwargs = {**kwargs.get("pipe_components", {})}
|
model_kwargs = {**kwargs.get("pipe_components", {})}
|
||||||
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
|
disable_nsfw_checker = kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False)
|
||||||
|
if disable_nsfw_checker:
|
||||||
logger.info("Disable Stable Diffusion Model NSFW checker")
|
logger.info("Disable Stable Diffusion Model NSFW checker")
|
||||||
model_kwargs.update(
|
model_kwargs.update(
|
||||||
dict(
|
dict(
|
||||||
@ -37,7 +38,10 @@ class SD(DiffusionInpaintModel):
|
|||||||
model_kwargs["num_in_channels"] = 9
|
model_kwargs["num_in_channels"] = 9
|
||||||
|
|
||||||
self.model = StableDiffusionInpaintPipeline.from_single_file(
|
self.model = StableDiffusionInpaintPipeline.from_single_file(
|
||||||
self.model_id_or_path, dtype=torch_dtype, **model_kwargs
|
self.model_id_or_path,
|
||||||
|
dtype=torch_dtype,
|
||||||
|
load_safety_checker=not disable_nsfw_checker,
|
||||||
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.model = handle_from_pretrained_exceptions(
|
self.model = handle_from_pretrained_exceptions(
|
||||||
|
@ -35,6 +35,7 @@ class SDXL(DiffusionInpaintModel):
|
|||||||
self.model_id_or_path,
|
self.model_id_or_path,
|
||||||
dtype=torch_dtype,
|
dtype=torch_dtype,
|
||||||
num_in_channels=num_in_channels,
|
num_in_channels=num_in_channels,
|
||||||
|
load_safety_checker=False
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model_kwargs = {**kwargs.get("pipe_components", {})}
|
model_kwargs = {**kwargs.get("pipe_components", {})}
|
||||||
|
Loading…
Reference in New Issue
Block a user