diff --git a/lama_cleaner/lama/__init__.py b/lama_cleaner/lama/__init__.py index 2fc0a60..3075e15 100644 --- a/lama_cleaner/lama/__init__.py +++ b/lama_cleaner/lama/__init__.py @@ -1,5 +1,4 @@ import os -import time import cv2 import torch @@ -20,7 +19,9 @@ class LaMa: if os.environ.get("LAMA_MODEL"): model_path = os.environ.get("LAMA_MODEL") if not os.path.exists(model_path): - raise FileNotFoundError(f"lama torchscript model not found: {model_path}") + raise FileNotFoundError( + f"lama torchscript model not found: {model_path}" + ) else: model_path = download_model(LAMA_MODEL_URL) @@ -45,10 +46,8 @@ class LaMa: image = torch.from_numpy(image).unsqueeze(0).to(device) mask = torch.from_numpy(mask).unsqueeze(0).to(device) - start = time.time() inpainted_image = self.model(image, mask) - print(f"process time: {(time.time() - start) * 1000}ms") cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy() cur_res = cur_res[0:origin_height, 0:origin_width, :] cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8") diff --git a/main.py b/main.py index 8160f73..18968a1 100644 --- a/main.py +++ b/main.py @@ -4,6 +4,7 @@ import argparse import io import multiprocessing import os +import time from typing import Union import cv2 @@ -73,7 +74,9 @@ def process(): mask = resize_max_size(mask, size_limit=size_limit, interpolation=interpolation) mask = norm_img(mask) + start = time.time() res_np_img = model(image, mask) + print(f"process time: {(time.time() - start) * 1000}ms") torch.cuda.empty_cache()