diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py index 3625728..1c607ed 100644 --- a/lama_cleaner/parse_args.py +++ b/lama_cleaner/parse_args.py @@ -75,6 +75,11 @@ def parse_args(): args = parser.parse_args() + if args.device == "cuda": + import torch + if torch.cuda.is_available() is False: + parser.error("torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation") + if args.model_dir is not None: if os.path.isfile(args.model_dir): parser.error(f"invalid --model-dir: {args.model_dir} is a file")