IOPaint/lama_cleaner/plugins/interactive_seg.py

83 lines
2.8 KiB
Python
Raw Normal View History

2024-01-02 04:07:35 +01:00
import hashlib
2023-03-22 05:57:18 +01:00
import json
2024-01-02 04:07:35 +01:00
from typing import List
2022-11-27 14:25:27 +01:00
import cv2
2023-03-22 05:57:18 +01:00
import numpy as np
2022-11-27 14:25:27 +01:00
from loguru import logger
2023-04-06 15:55:20 +02:00
from lama_cleaner.helper import download_model
2023-03-26 06:37:58 +02:00
from lama_cleaner.plugins.base_plugin import BasePlugin
2023-04-06 15:55:20 +02:00
from lama_cleaner.plugins.segment_anything import SamPredictor, sam_model_registry
2024-01-02 04:07:35 +01:00
from lama_cleaner.schema import RunPluginRequest
2023-04-06 15:55:20 +02:00
# Segment Anything checkpoints keyed by model type. Each entry holds the
# download URL and the expected MD5 checksum that InteractiveSeg hands to
# ``download_model``. Per the original note, the SAM ViT backbones are
# listed from small to large (vit_b, vit_l, vit_h); mobile_sam follows them.
SEGMENT_ANYTHING_MODELS = dict(
    vit_b=dict(
        url="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
        md5="01ec64d29a2fca3f0661936605ae66f8",
    ),
    vit_l=dict(
        url="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
        md5="0b3195507c641ddb6910d2bb5adee89c",
    ),
    vit_h=dict(
        url="https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
        md5="4b8939a88964f0f4ff5f5b2642c598a6",
    ),
    mobile_sam=dict(
        url="https://github.com/Sanster/models/releases/download/MobileSAM/mobile_sam.pt",
        md5="f3c0d8cda613564d499310dab6c812cd",
    ),
)
2022-11-27 14:25:27 +01:00
2023-03-26 06:37:58 +02:00
class InteractiveSeg(BasePlugin):
    """Click-driven segmentation plugin backed by a Segment Anything predictor.

    Each call converts a list of user clicks into a dilated, semi-transparent
    4-channel overlay covering the segmented region.
    """

    name = "InteractiveSeg"

    def __init__(self, model_name, device):
        """Download (if needed) and load a Segment Anything checkpoint.

        Args:
            model_name: Key into ``SEGMENT_ANYTHING_MODELS`` and
                ``sam_model_registry``, e.g. ``"vit_b"`` or ``"mobile_sam"``.
            device: Target passed to ``.to()`` on the constructed model.

        Raises:
            KeyError: If ``model_name`` is not a key of
                ``SEGMENT_ANYTHING_MODELS``.
        """
        super().__init__()
        model_info = SEGMENT_ANYTHING_MODELS[model_name]
        model_path = download_model(model_info["url"], model_info["md5"])
        logger.info(f"SegmentAnything model path: {model_path}")
        self.predictor = SamPredictor(
            sam_model_registry[model_name](checkpoint=model_path).to(device)
        )
        # Identifier of the image most recently passed to
        # ``predictor.set_image``; lets ``forward`` skip ``set_image`` when the
        # same image is submitted again.
        self.prev_img_md5 = None

    def __call__(self, rgb_np_img, req: RunPluginRequest):
        """Run segmentation for a plugin request.

        The image identifier is the MD5 hex digest of ``req.image`` encoded as
        UTF-8, so repeated requests carrying the same image string reuse the
        image already loaded into the predictor.
        """
        img_md5 = hashlib.md5(req.image.encode("utf-8")).hexdigest()
        return self.forward(rgb_np_img, req.clicks, img_md5)

    def forward(self, rgb_np_img, clicks: List[List], img_md5: str):
        """Segment the region indicated by ``clicks``.

        Args:
            rgb_np_img: Image array handed to ``predictor.set_image``
                (name suggests an RGB HxWx3 array — TODO confirm).
            clicks: Sequence of ``[x, y, label]`` entries. ``x``/``y`` become
                the point coordinates and ``label`` is passed through as the
                point label.
            img_md5: Identifier for ``rgb_np_img``. If it equals the previous
                call's identifier, ``set_image`` is skipped. A falsy value
                means "unknown", and the image is always (re)set.

        Returns:
            A ``uint8`` array of shape ``(H, W, 4)``. It is zero everywhere
            except the dilated mask region, where pixels hold
            ``[0, 203, 255, 186]``.
        """
        input_point = [[click[0], click[1]] for click in clicks]
        input_label = [click[2] for click in clicks]

        # Without an identifier we cannot tell whether the predictor already
        # holds this image. Previously a falsy ``img_md5`` skipped
        # ``set_image`` entirely, leaving a stale image (or none) loaded.
        if not img_md5 or img_md5 != self.prev_img_md5:
            self.prev_img_md5 = img_md5
            self.predictor.set_image(rgb_np_img)

        masks, _, _ = self.predictor.predict(
            point_coords=np.array(input_point),
            point_labels=np.array(input_label),
            multimask_output=False,
        )
        # Boolean mask -> 0/255 uint8 so it can be dilated with OpenCV.
        mask = masks[0].astype(np.uint8) * 255
        # TODO: make the dilation kernel size configurable.
        kernel_size = 9
        mask = cv2.dilate(
            mask, np.ones((kernel_size, kernel_size), np.uint8), iterations=1
        )

        # Overlay colour mirrors the frontend brush colour "ffcc00bb".
        # NOTE(review): the pixel is written as [255, 203, 0, a] and then
        # swapped BGRA->RGBA, so the returned array holds [0, 203, 255, a].
        # This presumably relies on the caller encoding it with OpenCV (which
        # treats 4-channel data as BGRA) — confirm before changing.
        res_mask = np.zeros((mask.shape[0], mask.shape[1], 4), dtype=np.uint8)
        res_mask[mask == 255] = [255, 203, 0, int(255 * 0.73)]
        res_mask = cv2.cvtColor(res_mask, cv2.COLOR_BGRA2RGBA)
        return res_mask