fix controlnet torch_dtype

This commit is contained in:
Qing 2024-02-10 16:24:40 +08:00
parent 170e7342f4
commit 950fc4ce53

View File

@ -90,6 +90,7 @@ class ControlNet(DiffusionInpaintModel):
pretrained_model_name_or_path=controlnet_method, pretrained_model_name_or_path=controlnet_method,
resume_download=True, resume_download=True,
local_files_only=model_kwargs["local_files_only"], local_files_only=model_kwargs["local_files_only"],
torch_dtype=self.torch_dtype,
) )
if model_info.is_single_file_diffusers: if model_info.is_single_file_diffusers:
if self.model_info.model_type == ModelType.DIFFUSERS_SD: if self.model_info.model_type == ModelType.DIFFUSERS_SD:
@ -133,7 +134,10 @@ class ControlNet(DiffusionInpaintModel):
def switch_controlnet_method(self, new_method: str): def switch_controlnet_method(self, new_method: str):
self.controlnet_method = new_method self.controlnet_method = new_method
controlnet = ControlNetModel.from_pretrained( controlnet = ControlNetModel.from_pretrained(
new_method, resume_download=True, local_files_only=self.local_files_only new_method,
resume_download=True,
local_files_only=self.local_files_only,
torch_dtype=self.torch_dtype,
).to(self.model.device) ).to(self.model.device)
self.model.controlnet = controlnet self.model.controlnet = controlnet