add cuda check
This commit is contained in:
parent
719c7278b4
commit
7d00fc8ceb
@ -75,6 +75,11 @@ def parse_args():
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.device == "cuda":
|
||||||
|
import torch
|
||||||
|
if torch.cuda.is_available() is False:
|
||||||
|
parser.error("torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation")
|
||||||
|
|
||||||
if args.model_dir is not None:
|
if args.model_dir is not None:
|
||||||
if os.path.isfile(args.model_dir):
|
if os.path.isfile(args.model_dir):
|
||||||
parser.error(f"invalid --model-dir: {args.model_dir} is a file")
|
parser.error(f"invalid --model-dir: {args.model_dir} is a file")
|
||||||
|
Loading…
Reference in New Issue
Block a user