Merge pull request #7 from callmepantoine/main

bug fix : set eval mode during inference + don't store gradients
This commit is contained in:
Qing 2021-12-16 22:02:08 +08:00 committed by GitHub
commit a24df020d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

21
main.py
View File

@ -1,27 +1,22 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import argparse
import io import io
import multiprocessing
import os import os
import time import time
import argparse
from distutils.util import strtobool from distutils.util import strtobool
from typing import Union from typing import Union
import cv2 import cv2
import numpy as np import numpy as np
import torch import torch
from flask import Flask, request, send_file from flask import Flask, request, send_file
from flask_cors import CORS from flask_cors import CORS
from lama_cleaner.helper import (
download_model,
load_img,
norm_img,
resize_max_size,
numpy_to_bytes,
pad_img_to_modulo,
)
import multiprocessing from lama_cleaner.helper import (download_model, load_img, norm_img,
numpy_to_bytes, pad_img_to_modulo,
resize_max_size)
NUM_THREADS = str(multiprocessing.cpu_count()) NUM_THREADS = str(multiprocessing.cpu_count())
@ -102,7 +97,8 @@ def run(image, mask):
mask = torch.from_numpy(mask).unsqueeze(0).to(device) mask = torch.from_numpy(mask).unsqueeze(0).to(device)
start = time.time() start = time.time()
inpainted_image = model(image, mask) with torch.no_grad():
inpainted_image = model(image, mask)
print(f"process time: {(time.time() - start)*1000}ms") print(f"process time: {(time.time() - start)*1000}ms")
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy() cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
@ -135,6 +131,7 @@ def main():
model = torch.jit.load(model_path, map_location="cpu") model = torch.jit.load(model_path, map_location="cpu")
model = model.to(device) model = model.to(device)
model.eval()
app.run(host="0.0.0.0", port=args.port, debug=args.debug) app.run(host="0.0.0.0", port=args.port, debug=args.debug)