2022-03-04 06:44:53 +01:00
|
|
|
import os
|
2022-03-23 03:02:01 +01:00
|
|
|
from typing import List
|
2022-03-04 06:44:53 +01:00
|
|
|
|
|
|
|
import cv2
|
|
|
|
import torch
|
|
|
|
import numpy as np
|
|
|
|
|
2022-03-23 03:02:01 +01:00
|
|
|
from lama_cleaner.helper import pad_img_to_modulo, download_model, boxes_from_mask
|
2022-03-04 06:44:53 +01:00
|
|
|
|
|
|
|
# URL of the big-lama TorchScript checkpoint used when no local model file is
# configured. Setting the LAMA_MODEL_URL environment variable overrides it.
LAMA_MODEL_URL = os.getenv(
    "LAMA_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/add_big_lama/big-lama.pt",
)
|
|
|
|
|
|
|
|
|
|
|
|
class LaMa:
    """Inpainting wrapper around the LaMa TorchScript model.

    Images whose pixel area is below ``crop_trigger_size[0] * crop_trigger_size[1]``
    are inpainted in a single pass. Larger images are inpainted crop by crop:
    one crop per box returned by ``boxes_from_mask``, each padded by
    ``crop_margin`` pixels of context.
    """

    def __init__(self, crop_trigger_size: List[int], crop_margin: int, device):
        """
        Args:
            crop_trigger_size: (h, w). Images whose area (H * W) is at least
                h * w are inpainted crop by crop instead of in one pass.
            crop_margin: number of context pixels added on each side of a mask
                box when building its crop.
            device: torch device the model and the input tensors are moved to.

        Raises:
            FileNotFoundError: if the ``LAMA_MODEL`` environment variable is set
                but does not point to an existing file.
        """
        self.crop_trigger_size = crop_trigger_size
        self.crop_margin = crop_margin
        self.device = device

        model_path = os.environ.get("LAMA_MODEL")
        if model_path:
            # A user-supplied local TorchScript file takes precedence over the download.
            if not os.path.exists(model_path):
                raise FileNotFoundError(
                    f"lama torchscript model not found: {model_path}"
                )
        else:
            model_path = download_model(LAMA_MODEL_URL)

        # map_location="cpu" loads the weights onto the CPU regardless of where
        # they were saved; the model is then moved to the requested device.
        model = torch.jit.load(model_path, map_location="cpu")
        model = model.to(device)
        model.eval()
        self.model = model

    @torch.no_grad()
    def __call__(self, image, mask):
        """Inpaint ``image`` wherever ``mask`` is non-zero.

        Args:
            image: [C, H, W] RGB numpy array (presumably float in [0, 1]; it is
                scaled by 255 before the uint8 conversion).
            mask: [1, H, W] numpy array.

        Returns:
            [H, W, 3] uint8 BGR image, C-contiguous.
        """
        area = image.shape[1] * image.shape[2]
        if area < self.crop_trigger_size[0] * self.crop_trigger_size[1]:
            return self._run(image, mask)

        print("Trigger crop image")
        boxes = boxes_from_mask(mask)
        # Every crop is inpainted from the untouched input before anything is pasted.
        crop_result = [self._run_box(image, mask, box) for box in boxes]

        # Convert the whole input to the same uint8 BGR layout that _run returns.
        # Clip like _run does, so out-of-range values saturate instead of wrapping.
        # Copy the channel-reversed view into contiguous memory, because some
        # OpenCV functions reject negative-stride arrays.
        result = np.clip(image.transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)
        result = np.ascontiguousarray(result[:, :, ::-1])
        for crop_image, (x1, y1, x2, y2) in crop_result:
            result[y1:y2, x1:x2, :] = crop_image
        return result

    def _run_box(self, image, mask, box):
        """Inpaint the region around a single mask box.

        Args:
            image: [C, H, W] RGB
            mask: [1, H, W]
            box: [left, top, right, bottom]. The width is taken as
                right - left, so right/bottom are exclusive.

        Returns:
            Tuple of (BGR crop produced by ``_run``,
            [left, top, right, bottom] of the crop in image coordinates).
        """
        box_h = box[3] - box[1]
        box_w = box[2] - box[0]
        cx = (box[0] + box[2]) // 2
        cy = (box[1] + box[3]) // 2
        img_h, img_w = image.shape[1:]

        w = box_w + self.crop_margin * 2
        h = box_h + self.crop_margin * 2

        # Derive right/bottom from left/top plus the full size. The previous
        # `cx + w // 2` made the window one pixel short whenever w was odd.
        # With crop_margin == 0, that cut the last masked column/row out of the crop.
        left_raw = cx - w // 2
        top_raw = cy - h // 2
        left = max(left_raw, 0)
        top = max(top_raw, 0)
        right = min(left_raw + w, img_w)
        bottom = min(top_raw + h, img_h)

        crop_img = image[:, top:bottom, left:right]
        crop_mask = mask[:, top:bottom, left:right]

        print(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")

        return self._run(crop_img, crop_mask), [left, top, right, bottom]

    def _run(self, image, mask):
        """Run the model once on a whole image (or a crop of one).

        Args:
            image: [C, H, W] RGB
            mask: [1, H, W]; values > 0 are treated as 1, everything else as 0.

        Returns:
            [H, W, 3] uint8 BGR image with the input's original (unpadded) H and W.
        """
        device = self.device
        origin_height, origin_width = image.shape[1:]
        # pad_img_to_modulo presumably pads H/W up to a multiple of 8 for the
        # model; the padding is sliced off again below.
        image = pad_img_to_modulo(image, mod=8)
        mask = pad_img_to_modulo(mask, mod=8)

        # Binarize the mask: 1 where > 0, else 0.
        mask = (mask > 0) * 1
        image = torch.from_numpy(image).unsqueeze(0).to(device)
        mask = torch.from_numpy(mask).unsqueeze(0).to(device)

        inpainted_image = self.model(image, mask)

        # [1, C, H, W] -> [H, W, C], drop the padding, scale to uint8, RGB -> BGR.
        cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
        cur_res = cur_res[0:origin_height, 0:origin_width, :]
        cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
        cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
        return cur_res
|