change mps dtype to fp32

This commit is contained in:
Qing 2024-01-10 21:48:14 +08:00
parent c3cc7a238e
commit ccfb077a2e

View File

@ -1000,7 +1000,10 @@ def get_torch_dtype(device, no_half: bool):
    device = str(device)
    use_fp16 = not no_half
    use_gpu = device == "cuda"
if device in ["cuda", "mps"] and use_fp16: # https://github.com/huggingface/diffusers/issues/4480
# pipe.enable_attention_slicing and float16 will cause black output on mps
# if device in ["cuda", "mps"] and use_fp16:
if device in ["cuda"] and use_fp16:
        return use_gpu, torch.float16
    return use_gpu, torch.float32