Merge pull request #19 from Sanster/add_crop_infer
add crop infer for lama
commit 1207b6e291
README.md

@@ -20,6 +20,10 @@ Install requirements: `pip3 install -r requirements.txt`
 python3 main.py --device=cuda --port=8080 --model=lama
 ```
 
+- `--crop-trigger-size`: If the image size is larger than crop-trigger-size, crop each area from the original image to do inference.
+  Mainly for performance and memory reasons on **very** large images. Default is 2042,2042.
+- `--crop-size`: Crop size used when `--crop-trigger-size` is exceeded. Default is 512,512.
+
 ### Start server with LDM model
 
 ```bash
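A note on the new flags: the trigger compares total pixel area, not either dimension alone, so a long panorama can trip it even when one side is under 2042. A minimal sketch of the check, using the default values above (the comparison mirrors the one added to `LaMa.__call__` below):

```python
# Sketch of the crop trigger, using the default --crop-trigger-size.
crop_trigger_size = [2042, 2042]  # h, w
image_shape = (3, 1000, 5000)     # C, H, W: a long, thin panorama

area = image_shape[1] * image_shape[2]                 # 5_000_000
trigger = crop_trigger_size[0] * crop_trigger_size[1]  # 4_169_764

# Cropping triggers even though the height (1000) is well under 2042,
# because the comparison is on total pixel area.
print(area >= trigger)  # True
```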
@@ -35,7 +39,6 @@ results than LaMa.
 |--------------|------|----|
 |![photo-1583445095369-9c651e7e5d34](https://user-images.githubusercontent.com/3998421/156923525-d6afdec3-7b98-403f-ad20-88ebc6eb8d6d.jpg)|![photo-1583445095369-9c651e7e5d34_cleanup_lama](https://user-images.githubusercontent.com/3998421/156923620-a40cc066-fd4a-4d85-a29f-6458711d1247.png)|![photo-1583445095369-9c651e7e5d34_cleanup_ldm](https://user-images.githubusercontent.com/3998421/156923652-0d06c8c8-33ad-4a42-a717-9c99f3268933.png)|
-
 
 Blogs about diffusion models:
 
 - https://lilianweng.github.io/posts/2021-07-11-diffusion-models/
lama_cleaner/helper.py

@@ -1,5 +1,6 @@
 import os
 import sys
+from typing import List
 
 from urllib.parse import urlparse
 import cv2
@@ -80,3 +81,27 @@ def pad_img_to_modulo(img, mod):
         ((0, 0), (0, out_height - height), (0, out_width - width)),
         mode="symmetric",
     )
+
+
+def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
+    """
+    Args:
+        mask: (1, h, w), values in 0~1
+
+    Returns:
+        List of bounding boxes, each [left, top, right, bottom]
+    """
+    height, width = mask.shape[1:]
+    _, thresh = cv2.threshold((mask.transpose(1, 2, 0) * 255).astype(np.uint8), 127, 255, 0)
+    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+    boxes = []
+    for cnt in contours:
+        x, y, w, h = cv2.boundingRect(cnt)
+        box = np.array([x, y, x + w, y + h]).astype(int)
+
+        box[::2] = np.clip(box[::2], 0, width)
+        box[1::2] = np.clip(box[1::2], 0, height)
+        boxes.append(box)
+
+    return boxes
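For a quick feel of what `boxes_from_mask` returns, here is a hypothetical usage sketch; the synthetic rectangle stands in for the real `tests/mask.jpg` used by the new test below:

```python
import numpy as np

from lama_cleaner.helper import boxes_from_mask

# One white rectangle on a (1, h, w) mask with values in 0~1.
mask = np.zeros((1, 256, 256), dtype=np.float32)
mask[:, 40:120, 60:200] = 1.0

boxes = boxes_from_mask(mask)
print(boxes)  # expect one box close to [60, 40, 200, 120]
```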
@@ -1,10 +1,11 @@
 import os
+from typing import List
 
 import cv2
 import torch
 import numpy as np
 
-from lama_cleaner.helper import pad_img_to_modulo, download_model
+from lama_cleaner.helper import pad_img_to_modulo, download_model, boxes_from_mask
 
 LAMA_MODEL_URL = os.environ.get(
     "LAMA_MODEL_URL",
@@ -13,7 +14,16 @@ LAMA_MODEL_URL = os.environ.get(
 
 
 class LaMa:
-    def __init__(self, device):
+    def __init__(self, crop_trigger_size: List[int], crop_size: List[int], device):
+        """
+
+        Args:
+            crop_trigger_size: h, w
+            crop_size: h, w
+            device: torch.device to run the model on
+        """
+        self.crop_trigger_size = crop_trigger_size
+        self.crop_size = crop_size
         self.device = device
 
         if os.environ.get("LAMA_MODEL"):
@@ -32,6 +42,63 @@ class LaMa:
 
     @torch.no_grad()
     def __call__(self, image, mask):
+        """
+        image: [C, H, W] RGB
+        mask: [1, H, W]
+        return: BGR IMAGE
+        """
+        area = image.shape[1] * image.shape[2]
+        if area < self.crop_trigger_size[0] * self.crop_trigger_size[1]:
+            return self._run(image, mask)
+
+        print("Trigger crop image")
+        boxes = boxes_from_mask(mask)
+        crop_result = []
+        for box in boxes:
+            crop_image, crop_box = self._run_box(image, mask, box)
+            crop_result.append((crop_image, crop_box))
+
+        image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)[:, :, ::-1]
+        for crop_image, crop_box in crop_result:
+            x1, y1, x2, y2 = crop_box
+            image[y1:y2, x1:x2, :] = crop_image
+        return image
+
+    def _run_box(self, image, mask, box):
+        """
+
+        Args:
+            image: [C, H, W] RGB
+            mask: [1, H, W]
+            box: [left, top, right, bottom]
+
+        Returns:
+            BGR IMAGE
+        """
+        box_h = box[3] - box[1]
+        box_w = box[2] - box[0]
+        cx = (box[0] + box[2]) // 2
+        cy = (box[1] + box[3]) // 2
+        crop_h, crop_w = self.crop_size
+        img_h, img_w = image.shape[1:]
+
+        # TODO: when box_w > crop_w, add some margin around?
+        w = max(crop_w, box_w)
+        h = max(crop_h, box_h)
+
+        l = max(cx - w // 2, 0)
+        t = max(cy - h // 2, 0)
+        r = min(cx + w // 2, img_w)
+        b = min(cy + h // 2, img_h)
+
+        crop_img = image[:, t:b, l:r]
+        crop_mask = mask[:, t:b, l:r]
+
+        print(f"Apply zoom in, crop shape (C, H, W): {crop_img.shape}")
+
+        return self._run(crop_img, crop_mask), [l, t, r, b]
+
+    def _run(self, image, mask):
         """
         image: [C, H, W] RGB
         mask: [1, H, W]
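To make the window arithmetic in `_run_box` concrete, a worked example with hypothetical numbers (`box` is [left, top, right, bottom]; `crop_size` is h, w):

```python
# A small 80x80 mask box near the top-left corner of a 3000x4000 (h x w)
# image, with the default crop_size of 512x512.
box = [100, 120, 180, 200]
crop_h, crop_w = 512, 512
img_h, img_w = 3000, 4000

cx = (box[0] + box[2]) // 2  # 140
cy = (box[1] + box[3]) // 2  # 160
w = max(crop_w, box[2] - box[0])  # 512: the box is smaller than the crop
h = max(crop_h, box[3] - box[1])  # 512

l = max(cx - w // 2, 0)  # max(-116, 0) -> 0, clamped at the left edge
t = max(cy - h // 2, 0)  # max(-96, 0)  -> 0, clamped at the top edge
r = min(cx + w // 2, img_w)  # 396
b = min(cy + h // 2, img_h)  # 416

# The crop ends up 396x416 rather than 512x512: the window is clamped to
# the image bounds, not shifted inward to keep its full size.
print(l, t, r, b)  # 0 0 396 416
```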
@@ -51,5 +118,5 @@ class LaMa:
         cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy()
         cur_res = cur_res[0:origin_height, 0:origin_width, :]
         cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
-        cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB)
+        cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
         return cur_res
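The last change flips the conversion constant rather than the behaviour: for 3-channel images `COLOR_BGR2RGB` and `COLOR_RGB2BGR` perform the same channel swap, so the rename documents the actual direction (model output is RGB, the returned image is BGR). A small illustrative check, not part of the diff:

```python
import cv2
import numpy as np

img = np.random.randint(0, 255, (4, 4, 3), dtype=np.uint8)

# Both constants swap channels 0 and 2, so the results are identical.
a = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
b = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
assert (a == b).all() and (b == img[:, :, ::-1]).all()
```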
BIN lama_cleaner/tests/mask.jpg (new file, 11 KiB)
Binary file not shown.
lama_cleaner/tests/test_boxes_from_mask.py (new file)

@@ -0,0 +1,15 @@
+import cv2
+import numpy as np
+
+from lama_cleaner.helper import boxes_from_mask
+
+
+def test_boxes_from_mask():
+    mask = cv2.imread("mask.jpg", cv2.IMREAD_GRAYSCALE)
+    mask = mask[:, :, np.newaxis]
+    mask = (mask / 255).transpose(2, 0, 1)
+    boxes = boxes_from_mask(mask)
+    print(boxes)
+
+
+test_boxes_from_mask()
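One caveat with the test as written: `cv2.imread("mask.jpg")` resolves the path against the current working directory, so it only finds the image when run from `lama_cleaner/tests`. A hypothetical path-independent variant:

```python
import os

import cv2
import numpy as np

from lama_cleaner.helper import boxes_from_mask

# Resolve the mask relative to this test file, not the working directory.
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))


def test_boxes_from_mask():
    mask = cv2.imread(os.path.join(CURRENT_DIR, "mask.jpg"), cv2.IMREAD_GRAYSCALE)
    mask = (mask[:, :, np.newaxis] / 255).transpose(2, 0, 1)
    boxes = boxes_from_mask(mask)
    assert len(boxes) > 0  # the sample mask contains at least one region
```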
main.py

@@ -97,12 +97,18 @@ def get_args_parser():
     parser = argparse.ArgumentParser()
     parser.add_argument("--port", default=8080, type=int)
     parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
+    parser.add_argument("--crop-trigger-size", default="2042,2042",
+                        help="If the image size is larger than crop-trigger-size, "
+                             "crop each area from the original image to do inference. "
+                             "Mainly for performance and memory reasons. "
+                             "Only for lama")
+    parser.add_argument("--crop-size", default="512,512")
     parser.add_argument(
         "--ldm-steps",
         default=50,
         type=int,
         help="Steps for DDIM sampling process."
         "The larger the value, the better the result, but it will be more time-consuming",
     )
     parser.add_argument("--device", default="cuda", type=str)
     parser.add_argument("--debug", action="store_true")
@@ -115,8 +121,11 @@ def main():
     args = get_args_parser()
     device = torch.device(args.device)
 
+    crop_trigger_size = [int(it) for it in args.crop_trigger_size.split(",")]
+    crop_size = [int(it) for it in args.crop_size.split(",")]
+
     if args.model == "lama":
-        model = LaMa(device)
+        model = LaMa(crop_trigger_size=crop_trigger_size, crop_size=crop_size, device=device)
     elif args.model == "ldm":
         model = LDM(device, steps=args.ldm_steps)
     else:
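Putting it together, the flag strings flow into the model like this; a minimal wiring sketch (the `lama_cleaner.lama` import path is an assumption, substitute whatever module defines `LaMa`):

```python
import torch

from lama_cleaner.lama import LaMa  # assumed import path

# "2042,2042" -> [2042, 2042]; the same split is applied to --crop-size.
crop_trigger_size = [int(it) for it in "2042,2042".split(",")]
crop_size = [int(it) for it in "512,512".split(",")]

model = LaMa(
    crop_trigger_size=crop_trigger_size,  # h, w area threshold
    crop_size=crop_size,                  # h, w of each crop window
    device=torch.device("cuda"),
)
```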