add LAMA_MODEL env for loading other lama model
This commit is contained in:
parent
a2036c71a2
commit
05e4c0993d
9
main.py
9
main.py
@ -125,7 +125,14 @@ def main():
|
|||||||
global device
|
global device
|
||||||
args = get_args_parser()
|
args = get_args_parser()
|
||||||
device = torch.device(args.device)
|
device = torch.device(args.device)
|
||||||
model_path = download_model()
|
|
||||||
|
if os.environ.get("LAMA_MODEL"):
|
||||||
|
model_path = os.environ.get("LAMA_MODEL")
|
||||||
|
if not os.path.exists(model_path):
|
||||||
|
raise FileNotFoundError(f"lama torchscript model not found: {model_path}")
|
||||||
|
else:
|
||||||
|
model_path = download_model()
|
||||||
|
|
||||||
model = torch.jit.load(model_path, map_location="cpu")
|
model = torch.jit.load(model_path, map_location="cpu")
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
app.run(host="0.0.0.0", port=args.port, debug=args.debug)
|
app.run(host="0.0.0.0", port=args.port, debug=args.debug)
|
||||||
|
Loading…
Reference in New Issue
Block a user