add LAMA_MODEL env for loading other lama model

This commit is contained in:
Sanster 2021-12-05 21:32:18 +08:00
parent a2036c71a2
commit 05e4c0993d

View File

@ -125,7 +125,14 @@ def main():
global device
args = get_args_parser()
device = torch.device(args.device)
model_path = download_model()
if os.environ.get("LAMA_MODEL"):
model_path = os.environ.get("LAMA_MODEL")
if not os.path.exists(model_path):
raise FileNotFoundError(f"lama torchscript model not found: {model_path}")
else:
model_path = download_model()
model = torch.jit.load(model_path, map_location="cpu")
model = model.to(device)
app.run(host="0.0.0.0", port=args.port, debug=args.debug)