diff --git a/iopaint/model/controlnet.py b/iopaint/model/controlnet.py index c738b13..b908674 100644 --- a/iopaint/model/controlnet.py +++ b/iopaint/model/controlnet.py @@ -111,7 +111,7 @@ class ControlNet(DiffusionInpaintModel): pretrained_model_name_or_path=model_info.path, controlnet=controlnet, variant="fp16", - dtype=torch_dtype, + torch_dtype=torch_dtype, **model_kwargs, ) diff --git a/iopaint/model/helper/cpu_text_encoder.py b/iopaint/model/helper/cpu_text_encoder.py index 889bc97..116eb48 100644 --- a/iopaint/model/helper/cpu_text_encoder.py +++ b/iopaint/model/helper/cpu_text_encoder.py @@ -8,6 +8,7 @@ class CPUTextEncoderWrapper(PreTrainedModel): def __init__(self, text_encoder, torch_dtype): super().__init__(text_encoder.config) self.config = text_encoder.config + self._device = text_encoder.device # cpu not support float16 self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True) self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True) @@ -30,3 +31,11 @@ class CPUTextEncoderWrapper(PreTrainedModel): @property def dtype(self): return self.torch_dtype + + @property + def device(self) -> torch.device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + return self._device \ No newline at end of file diff --git a/iopaint/model/sd.py b/iopaint/model/sd.py index 4f20a41..8f42fff 100644 --- a/iopaint/model/sd.py +++ b/iopaint/model/sd.py @@ -50,7 +50,7 @@ class SD(DiffusionInpaintModel): self.model = StableDiffusionInpaintPipeline.from_single_file( self.model_id_or_path, - dtype=torch_dtype, + torch_dtype=torch_dtype, load_safety_checker=not disable_nsfw_checker, config_files=get_config_files(), **model_kwargs, @@ -60,7 +60,7 @@ class SD(DiffusionInpaintModel): StableDiffusionInpaintPipeline.from_pretrained, pretrained_model_name_or_path=self.model_id_or_path, variant="fp16", - dtype=torch_dtype, + torch_dtype=torch_dtype, **model_kwargs, ) diff --git a/iopaint/model/sdxl.py b/iopaint/model/sdxl.py index 2557e71..29312b1 100644 --- a/iopaint/model/sdxl.py +++ b/iopaint/model/sdxl.py @@ -39,7 +39,7 @@ class SDXL(DiffusionInpaintModel): if os.path.isfile(self.model_id_or_path): self.model = StableDiffusionXLInpaintPipeline.from_single_file( self.model_id_or_path, - dtype=torch_dtype, + torch_dtype=torch_dtype, num_in_channels=num_in_channels, load_safety_checker=False, config_files=get_config_files()