From ccfb077a2e55837a5584ae186ed77fc98ae6a0d2 Mon Sep 17 00:00:00 2001 From: Qing Date: Wed, 10 Jan 2024 21:48:14 +0800 Subject: [PATCH] change mps dtype to fp32 --- iopaint/model/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/iopaint/model/utils.py b/iopaint/model/utils.py index c9b2b93..e1ae4e1 100644 --- a/iopaint/model/utils.py +++ b/iopaint/model/utils.py @@ -1000,7 +1000,10 @@ def get_torch_dtype(device, no_half: bool): device = str(device) use_fp16 = not no_half use_gpu = device == "cuda" - if device in ["cuda", "mps"] and use_fp16: + # https://github.com/huggingface/diffusers/issues/4480 + # pipe.enable_attention_slicing and float16 will cause black output on mps + # if device in ["cuda", "mps"] and use_fp16: + if device in ["cuda"] and use_fp16: return use_gpu, torch.float16 return use_gpu, torch.float32