diff --git a/main.py b/main.py index 7500240..064532b 100644 --- a/main.py +++ b/main.py @@ -125,7 +125,14 @@ def main(): global device args = get_args_parser() device = torch.device(args.device) - model_path = download_model() + + if os.environ.get("LAMA_MODEL"): + model_path = os.environ.get("LAMA_MODEL") + if not os.path.exists(model_path): + raise FileNotFoundError(f"lama torchscript model not found: {model_path}") + else: + model_path = download_model() + model = torch.jit.load(model_path, map_location="cpu") model = model.to(device) app.run(host="0.0.0.0", port=args.port, debug=args.debug)