move time to main
This commit is contained in:
parent
bb6580cc0c
commit
a46424478a
@ -1,5 +1,4 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import torch
|
import torch
|
||||||
@ -20,7 +19,9 @@ class LaMa:
|
|||||||
if os.environ.get("LAMA_MODEL"):
|
if os.environ.get("LAMA_MODEL"):
|
||||||
model_path = os.environ.get("LAMA_MODEL")
|
model_path = os.environ.get("LAMA_MODEL")
|
||||||
if not os.path.exists(model_path):
|
if not os.path.exists(model_path):
|
||||||
raise FileNotFoundError(f"lama torchscript model not found: {model_path}")
|
raise FileNotFoundError(
|
||||||
|
f"lama torchscript model not found: {model_path}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
model_path = download_model(LAMA_MODEL_URL)
|
model_path = download_model(LAMA_MODEL_URL)
|
||||||
|
|
||||||
@ -45,10 +46,8 @@ class LaMa:
|
|||||||
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
image = torch.from_numpy(image).unsqueeze(0).to(device)
|
||||||
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
mask = torch.from_numpy(mask).unsqueeze(0).to(device)
|
||||||
|
|
||||||
start = time.time()
|
|
||||||
inpainted_image = self.model(image, mask)
|
inpainted_image = self.model(image, mask)
|
||||||
|
|
||||||
print(f"process time: {(time.time() - start) * 1000}ms")
|
|
||||||
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
|
||||||
cur_res = cur_res[0:origin_height, 0:origin_width, :]
|
cur_res = cur_res[0:origin_height, 0:origin_width, :]
|
||||||
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
|
cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
|
||||||
|
3
main.py
3
main.py
@ -4,6 +4,7 @@ import argparse
|
|||||||
import io
|
import io
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
@ -73,7 +74,9 @@ def process():
|
|||||||
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation)
|
||||||
mask = norm_img(mask)
|
mask = norm_img(mask)
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
res_np_img = model(image, mask)
|
res_np_img = model(image, mask)
|
||||||
|
print(f"process time: {(time.time() - start) * 1000}ms")
|
||||||
|
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user