change mps dtype to fp32
This commit is contained in:
parent
c3cc7a238e
commit
ccfb077a2e
@ -1000,7 +1000,10 @@ def get_torch_dtype(device, no_half: bool):
|
|||||||
device = str(device)
|
device = str(device)
|
||||||
use_fp16 = not no_half
|
use_fp16 = not no_half
|
||||||
use_gpu = device == "cuda"
|
use_gpu = device == "cuda"
|
||||||
if device in ["cuda", "mps"] and use_fp16:
|
# https://github.com/huggingface/diffusers/issues/4480
|
||||||
|
# pipe.enable_attention_slicing and float16 will cause black output on mps
|
||||||
|
# if device in ["cuda", "mps"] and use_fp16:
|
||||||
|
if device in ["cuda"] and use_fp16:
|
||||||
return use_gpu, torch.float16
|
return use_gpu, torch.float16
|
||||||
return use_gpu, torch.float32
|
return use_gpu, torch.float32
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user