From 05e4c0993db086613d0708eefc8092966e1f7e03 Mon Sep 17 00:00:00 2001 From: Sanster Date: Sun, 5 Dec 2021 21:32:18 +0800 Subject: [PATCH] add LAMA_MODEL env for loading other lama model --- main.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/main.py b/main.py index 7500240..064532b 100644 --- a/main.py +++ b/main.py @@ -125,7 +125,14 @@ def main(): global device args = get_args_parser() device = torch.device(args.device) - model_path = download_model() + + if os.environ.get("LAMA_MODEL"): + model_path = os.environ.get("LAMA_MODEL") + if not os.path.exists(model_path): + raise FileNotFoundError(f"lama torchscript model not found: {model_path}") + else: + model_path = download_model() + model = torch.jit.load(model_path, map_location="cpu") model = model.to(device) app.run(host="0.0.0.0", port=args.port, debug=args.debug)