change mps dtype to fp32
This commit is contained in:
parent
c3cc7a238e
commit
ccfb077a2e
@ -1000,7 +1000,10 @@ def get_torch_dtype(device, no_half: bool):
|
||||
device = str(device)
|
||||
use_fp16 = not no_half
|
||||
use_gpu = device == "cuda"
|
||||
if device in ["cuda", "mps"] and use_fp16:
|
||||
# https://github.com/huggingface/diffusers/issues/4480
|
||||
# pipe.enable_attention_slicing and float16 will cause black output on mps
|
||||
# if device in ["cuda", "mps"] and use_fp16:
|
||||
if device in ["cuda"] and use_fp16:
|
||||
return use_gpu, torch.float16
|
||||
return use_gpu, torch.float32
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user