bug: set eval mode during inference
This commit is contained in:
parent
d0351a8603
commit
ffd527c2fd
21
main.py
21
main.py
@ -1,27 +1,22 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import argparse
|
||||||
import io
|
import io
|
||||||
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import argparse
|
|
||||||
from distutils.util import strtobool
|
from distutils.util import strtobool
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from flask import Flask, request, send_file
|
from flask import Flask, request, send_file
|
||||||
from flask_cors import CORS
|
from flask_cors import CORS
|
||||||
from lama_cleaner.helper import (
|
|
||||||
download_model,
|
|
||||||
load_img,
|
|
||||||
norm_img,
|
|
||||||
resize_max_size,
|
|
||||||
numpy_to_bytes,
|
|
||||||
pad_img_to_modulo,
|
|
||||||
)
|
|
||||||
|
|
||||||
import multiprocessing
|
from lama_cleaner.helper import (download_model, load_img, norm_img,
|
||||||
|
numpy_to_bytes, pad_img_to_modulo,
|
||||||
|
resize_max_size)
|
||||||
|
|
||||||
NUM_THREADS = str(multiprocessing.cpu_count())
|
NUM_THREADS = str(multiprocessing.cpu_count())
|
||||||
|
|
||||||
@ -102,7 +97,8 @@ def run(image, mask):
|
|||||||
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
||||||
|
|
||||||
start = time.time()
|
start = time.time()
|
||||||
inpainted_image = model(image, mask)
|
with torch.no_grad():
|
||||||
|
inpainted_image = model(image, mask)
|
||||||
|
|
||||||
print(f"process time: {(time.time() - start)*1000}ms")
|
print(f"process time: {(time.time() - start)*1000}ms")
|
||||||
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
||||||
@ -135,6 +131,7 @@ def main():
|
|||||||
|
|
||||||
model = torch.jit.load(model_path, map_location="cpu")
|
model = torch.jit.load(model_path, map_location="cpu")
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
model.eval()
|
||||||
app.run(host="0.0.0.0", port=args.port, debug=args.debug)
|
app.run(host="0.0.0.0", port=args.port, debug=args.debug)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user