change crop-size to crop-margin, to add more context for crop inference

Sanster 2022-03-24 09:08:49 +08:00
parent 1207b6e291
commit d3f1ea2474
4 changed files with 14 additions and 14 deletions

@@ -22,7 +22,7 @@ python3 main.py --device=cuda --port=8080 --model=lama
 - `--crop-trigger-size`: If image size large then crop-trigger-size, crop each area from original image to do inference.
   Mainly for performance and memory reasons on **very** large image.Default is 2042,2042
-- `--crop-size`: Crop size for `--crop-trigger-size`. Default is 512,512.
+- `--crop-margin`: Margin around bounding box of painted stroke when crop mode triggered. Default is 256.
 ### Start server with LDM model
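The flag description above boils down to a size check before inference: crop mode only kicks in when the image is larger than --crop-trigger-size, and each crop then carries --crop-margin pixels of extra context. A minimal sketch of that trigger check, assuming it fires when either dimension exceeds the configured size (the exact condition is outside this diff, and the function name is hypothetical):

    def should_crop(img_h: int, img_w: int, crop_trigger_size=(2042, 2042)) -> bool:
        # Sketch only: crop-based inference is used for images larger than the
        # trigger size; treating "larger" as either-dimension-exceeds is an
        # assumption, not something this diff shows.
        trigger_h, trigger_w = crop_trigger_size
        return img_h > trigger_h or img_w > trigger_w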

@@ -14,16 +14,16 @@ LAMA_MODEL_URL = os.environ.get(
 class LaMa:
-    def __init__(self, crop_trigger_size: List[int], crop_size: List[int], device):
+    def __init__(self, crop_trigger_size: List[int], crop_margin: int, device):
         """
         Args:
             crop_trigger_size: h, w
-            crop_size: h, w
+            crop_margin:
             device:
         """
         self.crop_trigger_size = crop_trigger_size
-        self.crop_size = crop_size
+        self.crop_margin = crop_margin
         self.device = device
         if os.environ.get("LAMA_MODEL"):
@@ -79,12 +79,10 @@ class LaMa:
         box_w = box[2] - box[0]
         cx = (box[0] + box[2]) // 2
         cy = (box[1] + box[3]) // 2
-        crop_h, crop_w = self.crop_size
         img_h, img_w = image.shape[1:]
-        # TODO: when box_w > crop_w, add some margin around?
-        w = max(crop_w, box_w)
-        h = max(crop_h, box_h)
+        w = box_w + self.crop_margin * 2
+        h = box_h + self.crop_margin * 2
         l = max(cx - w // 2, 0)
         t = max(cy - h // 2, 0)
@@ -94,7 +92,7 @@ class LaMa:
         crop_img = image[:, t:b, l:r]
         crop_mask = mask[:, t:b, l:r]
-        print(f"Apply zoom in size width x height: {crop_img.shape}")
+        print(f"box size: ({box_h},{box_w}) crop size: {crop_img.shape}")
         return self._run(crop_img, crop_mask), [l, t, r, b]
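Taken together, the LaMa hunks above swap the fixed crop window for one derived from the painted-stroke bounding box plus crop_margin pixels of context on every side. A self-contained sketch of that computation (function name hypothetical; the right/bottom clamping is assumed, since those lines fall outside the hunk):

    import numpy as np

    def crop_around_box(image: np.ndarray, mask: np.ndarray, box, crop_margin: int = 256):
        # image, mask: CHW arrays; box: (x1, y1, x2, y2) of the painted stroke.
        x1, y1, x2, y2 = box
        box_w, box_h = x2 - x1, y2 - y1
        cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
        img_h, img_w = image.shape[1:]

        # New behavior: stroke box size plus crop_margin pixels on each side.
        w = box_w + crop_margin * 2
        h = box_h + crop_margin * 2

        l = max(cx - w // 2, 0)
        t = max(cy - h // 2, 0)
        r = min(l + w, img_w)  # assumed clamping, not shown in the diff
        b = min(t + h, img_h)  # assumed clamping, not shown in the diff

        return image[:, t:b, l:r], mask[:, t:b, l:r], (l, t, r, b)

Under the old w = max(crop_w, box_w) logic, a stroke larger than the fixed crop size received no surrounding context at all (the removed TODO comment flags exactly this); with the margin-based window it always keeps crop_margin pixels of context.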

@@ -318,8 +318,10 @@ class LDM:
         cx = box_x + box_w // 2
         cy = box_y + box_h // 2
-        w = max(512, box_w)
-        h = max(512, box_h)
+        # w = max(512, box_w)
+        # h = max(512, box_h)
+        w = box_w + 512
+        h = box_h + 512
         left = max(cx - w // 2, 0)
         top = max(cy - h // 2, 0)

@@ -102,7 +102,8 @@ def get_args_parser():
         "crop each area from original image to do inference."
         "Mainly for performance and memory reasons"
         "Only for lama")
-    parser.add_argument("--crop-size", default="512,512")
+    parser.add_argument("--crop-margin", type=int, default=256,
+                        help="Margin around bounding box of painted stroke when crop mode triggered")
     parser.add_argument(
         "--ldm-steps",
         default=50,
@@ -122,10 +123,9 @@ def main():
     device = torch.device(args.device)
     crop_trigger_size = [int(it) for it in args.crop_trigger_size.split(",")]
-    crop_size = [int(it) for it in args.crop_size.split(",")]
     if args.model == "lama":
-        model = LaMa(crop_trigger_size=crop_trigger_size, crop_size=crop_size, device=device)
+        model = LaMa(crop_trigger_size=crop_trigger_size, crop_margin=args.crop_margin, device=device)
     elif args.model == "ldm":
         model = LDM(device, steps=args.ldm_steps)
     else: