diff --git a/lama_cleaner/__init__.py b/lama_cleaner/__init__.py index ff278e0..399fff8 100644 --- a/lama_cleaner/__init__.py +++ b/lama_cleaner/__init__.py @@ -2,8 +2,10 @@ import warnings warnings.simplefilter("ignore", UserWarning) from lama_cleaner.parse_args import parse_args -from lama_cleaner.server import main def entry_point(): args = parse_args() + # To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers + # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18 + from lama_cleaner.server import main main(args) diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py index ec6cc79..3625728 100644 --- a/lama_cleaner/parse_args.py +++ b/lama_cleaner/parse_args.py @@ -65,10 +65,27 @@ def parse_args(): "--output-dir", type=str, help="Only required when --input is directory. Output directory for all processed images" ) + parser.add_argument( + "--model-dir", type=str, default=None, + help="Model download directory (by setting XDG_CACHE_HOME environment variable), " + "by default model downloaded to ~/.cache" + ) parser.add_argument("--disable-model-switch", action="store_true", help="Disable model switch in frontend") parser.add_argument("--debug", action="store_true") args = parser.parse_args() + + if args.model_dir is not None: + if os.path.isfile(args.model_dir): + parser.error(f"invalid --model-dir: {args.model_dir} is a file") + + if not os.path.exists(args.model_dir): + logger.info(f"Create model cache directory: {args.model_dir}") + Path(args.model_dir).mkdir(exist_ok=True, parents=True) + + os.environ["XDG_CACHE_HOME"] = args.model_dir + + if args.input is not None: if not os.path.exists(args.input): parser.error(f"invalid --input: {args.input} not exists")