Qing 2023-03-01 09:17:39 +08:00
parent d1f36dde8e
commit e4bcd6ac00

View File

@ -23,9 +23,7 @@ def md5sum(filename):
def switch_mps_device(model_name, device):
if model_name not in MPS_SUPPORT_MODELS and (
device == "mps" or device == torch.device("mps")
):
if model_name not in MPS_SUPPORT_MODELS and str(device) == "mps":
logger.info(f"{model_name} not support mps, switch to cpu")
return torch.device("cpu")
return device