move time to main

This commit is contained in:
Sanster 2022-03-20 22:42:59 +08:00
parent bb6580cc0c
commit a46424478a
2 changed files with 6 additions and 4 deletions

View File

@ -1,5 +1,4 @@
import os import os
import time
import cv2 import cv2
import torch import torch
@ -20,7 +19,9 @@ class LaMa:
if os.environ.get("LAMA_MODEL"): if os.environ.get("LAMA_MODEL"):
model_path = os.environ.get("LAMA_MODEL") model_path = os.environ.get("LAMA_MODEL")
if not os.path.exists(model_path): if not os.path.exists(model_path):
raise FileNotFoundError(f"lama torchscript model not found: {model_path}") raise FileNotFoundError(
f"lama torchscript model not found: {model_path}"
)
else: else:
model_path = download_model(LAMA_MODEL_URL) model_path = download_model(LAMA_MODEL_URL)
@ -45,10 +46,8 @@ class LaMa:
image = torch.from_numpy(image).unsqueeze(0).to(device) image = torch.from_numpy(image).unsqueeze(0).to(device)
mask = torch.from_numpy(mask).unsqueeze(0).to(device) mask = torch.from_numpy(mask).unsqueeze(0).to(device)
start = time.time()
inpainted_image = self.model(image, mask) inpainted_image = self.model(image, mask)
print(f"process time: {(time.time() - start) * 1000}ms")
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy() cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
cur_res = cur_res[0:origin_height, 0:origin_width, :] cur_res = cur_res[0:origin_height, 0:origin_width, :]
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8") cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")

View File

@ -4,6 +4,7 @@ import argparse
import io import io
import multiprocessing import multiprocessing
import os import os
import time
from typing import Union from typing import Union
import cv2 import cv2
@ -73,7 +74,9 @@ def process():
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation) mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
mask = norm_img(mask) mask = norm_img(mask)
start = time.time()
res_np_img = model(image, mask) res_np_img = model(image, mask)
print(f"process time: {(time.time() - start) * 1000}ms")
torch.cuda.empty_cache() torch.cuda.empty_cache()