This commit is contained in:
parent
d1f36dde8e
commit
e4bcd6ac00
@ -23,9 +23,7 @@ def md5sum(filename):
|
|||||||
|
|
||||||
|
|
||||||
def switch_mps_device(model_name, device):
|
def switch_mps_device(model_name, device):
|
||||||
if model_name not in MPS_SUPPORT_MODELS and (
|
if model_name not in MPS_SUPPORT_MODELS and str(device) == "mps":
|
||||||
device == "mps" or device == torch.device("mps")
|
|
||||||
):
|
|
||||||
logger.info(f"{model_name} not support mps, switch to cpu")
|
logger.info(f"{model_name} not support mps, switch to cpu")
|
||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
return device
|
return device
|
||||||
|
Loading…
Reference in New Issue
Block a user