IOPaint/lama_cleaner/plugins/interactive_seg.py
2023-03-22 12:57:18 +08:00

264 lines
8.2 KiB
Python

import json
import json
import os
from typing import Tuple, List
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from loguru import logger
from pydantic import BaseModel
from lama_cleaner.helper import (
load_jit_model,
load_img,
)
class Click(BaseModel):
    """One user click used to guide interactive segmentation."""

    # Click position stored as [y, x].
    coords: Tuple[float, float]
    # True for a positive click; ISPredictor.get_points_nd groups positive
    # and negative clicks separately.
    is_positive: bool
    # Position of the click within the sequence it came from.
    indx: int

    @property
    def coords_and_indx(self):
        """The click flattened to ``(*coords, indx)``, i.e. ``(y, x, indx)``."""
        return tuple(self.coords) + (self.indx,)

    def scale(self, x_ratio: float, y_ratio: float) -> "Click":
        """Return a new Click whose coordinates are multiplied by the ratios.

        NOTE(review): ``x_ratio`` multiplies ``coords[0]`` (the y component)
        and ``y_ratio`` multiplies ``coords[1]`` (the x component), so the
        parameter names are swapped relative to what they scale.
        ResizeTrans.transform calls ``scale(y_ratio, x_ratio)`` positionally,
        which makes the net result there correct. The names are kept so that
        keyword callers keep their current behavior.
        """
        first = self.coords[0] * x_ratio
        second = self.coords[1] * y_ratio
        return Click(
            coords=(first, second),
            is_positive=self.is_positive,
            indx=self.indx,
        )
class ResizeTrans:
    """Resize a single-image batch to a fixed square size and back again.

    ``transform`` records the original spatial size on the instance, and
    ``inv_transform`` uses it, so ``transform`` must be called first.
    """

    def __init__(self, size=480):
        super().__init__()
        # Side length of the square the image is resized to.
        self.crop_height = self.crop_width = size

    def transform(self, image_nd, clicks_lists):
        """Resize ``image_nd`` to the target size and rescale the clicks.

        Args:
            image_nd: tensor indexed as [N, C, H, W]; only N == 1 is supported.
            clicks_lists: list holding exactly one list of clicks.

        Returns:
            (resized image tensor, list of rescaled click lists)
        """
        assert image_nd.shape[0] == 1 and len(clicks_lists) == 1
        self.image_height, self.image_width = image_nd.shape[2:4]

        target_size = (self.crop_height, self.crop_width)
        resized = F.interpolate(
            image_nd, size=target_size, mode="bilinear", align_corners=True
        )

        y_ratio = self.crop_height / self.image_height
        x_ratio = self.crop_width / self.image_width
        # Click.scale multiplies coords[0] by its first argument, so the
        # y ratio is passed first to scale the [y, x] coordinates correctly.
        rescaled = [
            [one_click.scale(y_ratio, x_ratio) for one_click in group]
            for group in clicks_lists
        ]
        return resized, rescaled

    def inv_transform(self, prob_map):
        """Resize ``prob_map`` back to the size recorded by ``transform``."""
        original_size = (self.image_height, self.image_width)
        return F.interpolate(
            prob_map, size=original_size, mode="bilinear", align_corners=True
        )
class ISPredictor(object):
    """Run a click-guided segmentation model on one image.

    The image and the previous mask are stacked, resized to ``infer_size``,
    and fed to ``model`` together with the encoded clicks. The sigmoid of
    the model output is cleaned up morphologically and resized back to the
    input resolution.
    """

    def __init__(
        self,
        model,
        device,
        open_kernel_size: int,
        dilate_kernel_size: int,
        net_clicks_limit=None,
        zoom_in=None,
        infer_size=384,
    ):
        """
        Args:
            model: callable invoked as ``model(image_nd, points_nd)``,
                returning logits.
            device: device the encoded click tensor is created on.
            open_kernel_size: side length of the elliptical kernel used for
                morphological opening; values <= 1 skip the opening.
            dilate_kernel_size: side length of the dilation kernel; values
                <= 1 skip the dilation.
            net_clicks_limit: optional cap on how many clicks per list are
                passed to the model (None means no cap).
            zoom_in: stored on the instance but not used by this class.
            infer_size: side length of the square input the model sees.
        """
        self.model = model
        self.open_kernel_size = open_kernel_size
        self.dilate_kernel_size = dilate_kernel_size
        self.net_clicks_limit = net_clicks_limit
        self.device = device
        self.zoom_in = zoom_in
        self.infer_size = infer_size

    def __call__(self, input_image: torch.Tensor, clicks: List[Click], prev_mask):
        """Predict a segmentation probability map for ``clicks``.

        Args:
            input_image: [1, 3, H, W], values in [0, 1]
            clicks: List[Click]
            prev_mask: [1, 1, H, W]

        Returns:
            numpy array of shape [H, W]: the post-processed probability map
            resized back to the input resolution.
        """
        transforms = [ResizeTrans(self.infer_size)]
        # The previous mask is passed to the model as a 4th input channel.
        image_nd = torch.cat((input_image, prev_mask), dim=1)
        clicks_lists = [clicks]
        # Chain the transforms: each one consumes the previous one's output
        # (the original re-applied every transform to the raw input).
        for t in transforms:
            image_nd, clicks_lists = t.transform(image_nd, clicks_lists)

        # image_nd: [1, 4, infer_size, infer_size]
        # points_nd: [1, 2 * num_points, 3]; rows are (y, x, indx)
        # clicks_lists[0][0] is a Click instance
        points_nd = self.get_points_nd(clicks_lists)
        # Inference only; no autograd graph is needed, and tensors that
        # require grad cannot be converted with .numpy() later on.
        with torch.no_grad():
            pred_logits = self.model(image_nd, points_nd)
            pred = torch.sigmoid(pred_logits)
        pred = self.post_process(pred)

        # Match the transformed input size before undoing the transforms.
        prediction = F.interpolate(
            pred, mode="bilinear", align_corners=True, size=image_nd.size()[2:]
        )
        for t in reversed(transforms):
            prediction = t.inv_transform(prediction)

        return prediction.detach().cpu().numpy()[0, 0]

    def post_process(self, pred: torch.Tensor) -> torch.Tensor:
        """Clean up a [1, 1, h, w] probability map with morphology.

        An elliptical opening removes small noisy specks. A dilation then
        grows the region slightly so that border pixels are not missed,
        which generally works better.

        Returns:
            CPU float tensor of shape [1, 1, h, w].
        """
        # detach(): .numpy() refuses tensors that require grad.
        pred_mask = pred.detach().cpu().numpy()[0][0]

        kernel_size = self.open_kernel_size
        # Skip degenerate kernels, mirroring the dilation branch below
        # (a 1x1 opening is a no-op anyway).
        if kernel_size > 1:
            kernel = cv2.getStructuringElement(
                cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)
            )
            pred_mask = cv2.morphologyEx(
                pred_mask, cv2.MORPH_OPEN, kernel, iterations=1
            )

        dilate_kernel_size = self.dilate_kernel_size
        if dilate_kernel_size > 1:
            # The original passed cv2.MORPH_DILATE (an operation code, not a
            # kernel shape); its value equals MORPH_CROSS, so the kernel is
            # unchanged.
            kernel = cv2.getStructuringElement(
                cv2.MORPH_CROSS, (dilate_kernel_size, dilate_kernel_size)
            )
            # The 3rd positional argument of cv2.dilate is `dst`, not the
            # iteration count, so pass iterations by keyword.
            pred_mask = cv2.dilate(pred_mask, kernel, iterations=1)

        return torch.from_numpy(pred_mask).unsqueeze(0).unsqueeze(0)

    def get_points_nd(self, clicks_lists):
        """Pack click lists into one padded tensor for the model.

        For every click list, positive clicks come first and negative clicks
        second. Each group is padded with (-1, -1, -1) to a common length
        ``num_max_points``, which is at least 1 and capped by
        ``net_clicks_limit`` when that is set.

        Returns:
            tensor of shape [len(clicks_lists), 2 * num_max_points, 3] on
            ``self.device``; each row is ``Click.coords_and_indx``.
        """
        num_pos_clicks = [
            sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists
        ]
        num_neg_clicks = [
            len(clicks_list) - num_pos
            for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)
        ]
        num_max_points = max(num_pos_clicks + num_neg_clicks)
        if self.net_clicks_limit is not None:
            num_max_points = min(self.net_clicks_limit, num_max_points)
        num_max_points = max(1, num_max_points)

        padding = (-1, -1, -1)
        total_clicks = []
        for clicks_list in clicks_lists:
            # Slicing with None keeps the whole list.
            clicks_list = clicks_list[: self.net_clicks_limit]
            pos_clicks = [
                click.coords_and_indx for click in clicks_list if click.is_positive
            ]
            neg_clicks = [
                click.coords_and_indx
                for click in clicks_list
                if not click.is_positive
            ]
            pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [padding]
            neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [padding]
            total_clicks.append(pos_clicks + neg_clicks)
        return torch.tensor(total_clicks, device=self.device)
# Download location and expected MD5 of the interactive segmentation model.
# Each can be overridden by an environment variable of the same name.
INTERACTIVE_SEG_MODEL_URL = os.getenv(
    "INTERACTIVE_SEG_MODEL_URL",
    "https://github.com/Sanster/models/releases/download/clickseg_pplnet/clickseg_pplnet.pt",
)
INTERACTIVE_SEG_MODEL_MD5 = os.getenv(
    "INTERACTIVE_SEG_MODEL_MD5",
    "8ca44b6e02bca78f62ec26a3c32376cf",
)
class InteractiveSeg:
    """Plugin that turns user clicks into a colored segmentation overlay."""

    name = "InteractiveSeg"

    def __init__(self, infer_size=384, open_kernel_size=3, dilate_kernel_size=3):
        cpu = torch.device("cpu")
        jit_model = load_jit_model(
            INTERACTIVE_SEG_MODEL_URL, cpu, INTERACTIVE_SEG_MODEL_MD5
        )
        self.predictor = ISPredictor(
            jit_model.eval(),
            cpu,
            infer_size=infer_size,
            open_kernel_size=open_kernel_size,
            dilate_kernel_size=dilate_kernel_size,
        )

    def __call__(self, rgb_np_img, files, form):
        """Handle one plugin request.

        Args:
            rgb_np_img: RGB image as a numpy array.
            files: mapping of uploaded files; an optional "mask" entry is
                loaded as a grayscale previous mask.
            form: mapping whose "clicks" entry is a JSON list of
                [x, y, label] items; label 1 marks a positive click.

        Returns:
            The overlay produced by ``forward``.
        """
        prev_mask = None
        if "mask" in files:
            prev_mask, _ = load_img(files["mask"].read(), gray=True)

        # Incoming clicks are [x, y, label]; Click stores coords as [y, x].
        clicks = [
            Click(coords=(raw[1], raw[0]), indx=i, is_positive=raw[2] == 1)
            for i, raw in enumerate(json.loads(form["clicks"]))
        ]
        return self.forward(rgb_np_img, clicks=clicks, prev_mask=prev_mask)

    def forward(self, image, clicks, prev_mask=None):
        """Segment ``image`` from ``clicks`` and render the result as an overlay.

        Args:
            image: [H, W, C] RGB numpy image (divided by 255 below, so
                presumably in the 0..255 range).
            clicks: list of Click.
            prev_mask: optional [H, W] mask, also divided by 255.

        Returns:
            [H, W, 4] uint8 array: background pixels are all zero and
            foreground pixels (probability > 0.5) carry the brush color.
        """
        bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        image = torch.from_numpy((bgr / 255).transpose(2, 0, 1)).unsqueeze(0).float()

        if prev_mask is not None:
            logger.info("InteractiveSeg run with prev_mask")
            mask = torch.from_numpy(prev_mask / 255).unsqueeze(0).unsqueeze(0).float()
        else:
            mask = torch.zeros_like(image[:, :1, :, :])

        probs = self.predictor(image, clicks, mask)
        binary = ((probs > 0.5) * 255).astype(np.uint8)

        # To simplify frontend processing, paint the brush color here.
        is_fg = binary == 255
        overlay = cv2.cvtColor(binary, cv2.COLOR_GRAY2BGRA)
        overlay[~is_fg] = 0
        # frontend brush color "ffcc00bb"
        # NOTE(review): this value is written in BGRA order and then swapped
        # to RGBA below, giving channels [0, 203, 255, 186]. That only reads
        # as ffcc00 if the consumer treats the array as BGRA (e.g.
        # cv2.imencode) — confirm against the caller. Also, 0xcc = 204 and
        # 0xbb = 187, not 203 / 186.
        overlay[is_fg] = [255, 203, 0, int(255 * 0.73)]
        return cv2.cvtColor(overlay, cv2.COLOR_BGRA2RGBA)