From ffd527c2fdb366b241942cd3471e3118cf1d21be Mon Sep 17 00:00:00 2001 From: callmepantoine Date: Thu, 16 Dec 2021 14:29:32 +0100 Subject: [PATCH] bug: set eval mode during inference --- main.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index 064532b..63ec312 100644 --- a/main.py +++ b/main.py @@ -1,27 +1,22 @@ #!/usr/bin/env python3 +import argparse import io +import multiprocessing import os import time -import argparse from distutils.util import strtobool from typing import Union + import cv2 import numpy as np import torch - from flask import Flask, request, send_file from flask_cors import CORS -from lama_cleaner.helper import ( - download_model, - load_img, - norm_img, - resize_max_size, - numpy_to_bytes, - pad_img_to_modulo, -) -import multiprocessing +from lama_cleaner.helper import (download_model, load_img, norm_img, + numpy_to_bytes, pad_img_to_modulo, + resize_max_size) NUM_THREADS = str(multiprocessing.cpu_count()) @@ -102,7 +97,8 @@ def run(image, mask): mask = torch.from_numpy(mask).unsqueeze(0).to(device) start = time.time() - inpainted_image = model(image, mask) + with torch.no_grad(): + inpainted_image = model(image, mask) print(f"process time: {(time.time() - start)*1000}ms") cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy() @@ -135,6 +131,7 @@ def main(): model = torch.jit.load(model_path, map_location="cpu") model = model.to(device) + model.eval() app.run(host="0.0.0.0", port=args.port, debug=args.debug)