bug: set eval mode during inference

This commit is contained in:
callmepantoine 2021-12-16 14:29:32 +01:00
parent d0351a8603
commit ffd527c2fd

19
main.py
View File

@ -1,27 +1,22 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse
import io import io
import multiprocessing
import os import os
import time import time
import argparse
from distutils.util import strtobool from distutils.util import strtobool
from typing import Union from typing import Union
import cv2 import cv2
import numpy as np import numpy as np
import torch import torch
from flask import Flask, request, send_file from flask import Flask, request, send_file
from flask_cors import CORS from flask_cors import CORS
from lama_cleaner.helper import (
download_model,
load_img,
norm_img,
resize_max_size,
numpy_to_bytes,
pad_img_to_modulo,
)
import multiprocessing from lama_cleaner.helper import (download_model, load_img, norm_img,
numpy_to_bytes, pad_img_to_modulo,
resize_max_size)
NUM_THREADS = str(multiprocessing.cpu_count()) NUM_THREADS = str(multiprocessing.cpu_count())
@ -102,6 +97,7 @@ def run(image, mask):
mask = torch.from_numpy(mask).unsqueeze(0).to(device) mask = torch.from_numpy(mask).unsqueeze(0).to(device)
start = time.time() start = time.time()
with torch.no_grad():
inpainted_image = model(image, mask) inpainted_image = model(image, mask)
print(f"process time: {(time.time() - start)*1000}ms") print(f"process time: {(time.time() - start)*1000}ms")
@ -135,6 +131,7 @@ def main():
model = torch.jit.load(model_path, map_location="cpu") model = torch.jit.load(model_path, map_location="cpu")
model = model.to(device) model = model.to(device)
model.eval()
app.run(host="0.0.0.0", port=args.port, debug=args.debug) app.run(host="0.0.0.0", port=args.port, debug=args.debug)