fix torch_dtype & cpu_text_encoder

This commit is contained in:
Qing 2024-02-10 10:59:06 +08:00
parent 180b7d6c70
commit 9aa5a7e0ba
4 changed files with 13 additions and 4 deletions

View File

@ -111,7 +111,7 @@ class ControlNet(DiffusionInpaintModel):
pretrained_model_name_or_path=model_info.path, pretrained_model_name_or_path=model_info.path,
controlnet=controlnet, controlnet=controlnet,
variant="fp16", variant="fp16",
dtype=torch_dtype, torch_dtype=torch_dtype,
**model_kwargs, **model_kwargs,
) )

View File

@ -8,6 +8,7 @@ class CPUTextEncoderWrapper(PreTrainedModel):
def __init__(self, text_encoder, torch_dtype): def __init__(self, text_encoder, torch_dtype):
super().__init__(text_encoder.config) super().__init__(text_encoder.config)
self.config = text_encoder.config self.config = text_encoder.config
self._device = text_encoder.device
# cpu not support float16 # cpu not support float16
self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True) self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True) self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
@ -30,3 +31,11 @@ class CPUTextEncoderWrapper(PreTrainedModel):
@property @property
def dtype(self): def dtype(self):
return self.torch_dtype return self.torch_dtype
@property
def device(self) -> torch.device:
"""
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
device).
"""
return self._device

View File

@ -50,7 +50,7 @@ class SD(DiffusionInpaintModel):
self.model = StableDiffusionInpaintPipeline.from_single_file( self.model = StableDiffusionInpaintPipeline.from_single_file(
self.model_id_or_path, self.model_id_or_path,
dtype=torch_dtype, torch_dtype=torch_dtype,
load_safety_checker=not disable_nsfw_checker, load_safety_checker=not disable_nsfw_checker,
config_files=get_config_files(), config_files=get_config_files(),
**model_kwargs, **model_kwargs,
@ -60,7 +60,7 @@ class SD(DiffusionInpaintModel):
StableDiffusionInpaintPipeline.from_pretrained, StableDiffusionInpaintPipeline.from_pretrained,
pretrained_model_name_or_path=self.model_id_or_path, pretrained_model_name_or_path=self.model_id_or_path,
variant="fp16", variant="fp16",
dtype=torch_dtype, torch_dtype=torch_dtype,
**model_kwargs, **model_kwargs,
) )

View File

@ -39,7 +39,7 @@ class SDXL(DiffusionInpaintModel):
if os.path.isfile(self.model_id_or_path): if os.path.isfile(self.model_id_or_path):
self.model = StableDiffusionXLInpaintPipeline.from_single_file( self.model = StableDiffusionXLInpaintPipeline.from_single_file(
self.model_id_or_path, self.model_id_or_path,
dtype=torch_dtype, torch_dtype=torch_dtype,
num_in_channels=num_in_channels, num_in_channels=num_in_channels,
load_safety_checker=False, load_safety_checker=False,
config_files=get_config_files() config_files=get_config_files()