diff --git a/iopaint/model/utils.py b/iopaint/model/utils.py index c9b2b93..e1ae4e1 100644 --- a/iopaint/model/utils.py +++ b/iopaint/model/utils.py @@ -1000,7 +1000,10 @@ def get_torch_dtype(device, no_half: bool): device = str(device) use_fp16 = not no_half use_gpu = device == "cuda" - if device in ["cuda", "mps"] and use_fp16: + # https://github.com/huggingface/diffusers/issues/4480 + # pipe.enable_attention_slicing and float16 will cause black output on mps + # if device in ["cuda", "mps"] and use_fp16: + if device in ["cuda"] and use_fp16: return use_gpu, torch.float16 return use_gpu, torch.float32