add cuda check

This commit is contained in:
Qing 2023-01-14 22:00:34 +08:00
parent 719c7278b4
commit 7d00fc8ceb

View File

@ -75,6 +75,11 @@ def parse_args():
args = parser.parse_args() args = parser.parse_args()
if args.device == "cuda":
import torch
if torch.cuda.is_available() is False:
parser.error("torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation")
if args.model_dir is not None: if args.model_dir is not None:
if os.path.isfile(args.model_dir): if os.path.isfile(args.model_dir):
parser.error(f"invalid --model-dir: {args.model_dir} is a file") parser.error(f"invalid --model-dir: {args.model_dir} is a file")