This commit is contained in:
parent
d1f36dde8e
commit
e4bcd6ac00
@ -23,9 +23,7 @@ def md5sum(filename):
|
||||
|
||||
|
||||
def switch_mps_device(model_name, device):
|
||||
if model_name not in MPS_SUPPORT_MODELS and (
|
||||
device == "mps" or device == torch.device("mps")
|
||||
):
|
||||
if model_name not in MPS_SUPPORT_MODELS and str(device) == "mps":
|
||||
logger.info(f"{model_name} not support mps, switch to cpu")
|
||||
return torch.device("cpu")
|
||||
return device
|
||||
|
Loading…
Reference in New Issue
Block a user