Merge pull request #7 from callmepantoine/main
bug fix : set eval mode during inference + don't store gradients
This commit is contained in:
commit
a24df020d3
19
main.py
19
main.py
@ -1,27 +1,22 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import io
|
||||
import multiprocessing
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
from distutils.util import strtobool
|
||||
from typing import Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from flask import Flask, request, send_file
|
||||
from flask_cors import CORS
|
||||
from lama_cleaner.helper import (
|
||||
download_model,
|
||||
load_img,
|
||||
norm_img,
|
||||
resize_max_size,
|
||||
numpy_to_bytes,
|
||||
pad_img_to_modulo,
|
||||
)
|
||||
|
||||
import multiprocessing
|
||||
from lama_cleaner.helper import (download_model, load_img, norm_img,
|
||||
numpy_to_bytes, pad_img_to_modulo,
|
||||
resize_max_size)
|
||||
|
||||
NUM_THREADS = str(multiprocessing.cpu_count())
|
||||
|
||||
@ -102,6 +97,7 @@ def run(image, mask):
|
||||
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
||||
|
||||
start = time.time()
|
||||
with torch.no_grad():
|
||||
inpainted_image = model(image, mask)
|
||||
|
||||
print(f"process time: {(time.time() - start)*1000}ms")
|
||||
@ -135,6 +131,7 @@ def main():
|
||||
|
||||
model = torch.jit.load(model_path, map_location="cpu")
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
app.run(host="0.0.0.0", port=args.port, debug=args.debug)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user