From 43c9c22c7312dd39feac4e3783e9ec080fd64243 Mon Sep 17 00:00:00 2001
From: Sanster
Date: Wed, 23 Mar 2022 10:02:01 +0800
Subject: [PATCH] add crop info for lama

---
 README.md                                  |  5 +-
 lama_cleaner/helper.py                     | 25 +++++++
 lama_cleaner/lama/__init__.py              | 73 ++++++++++++++++++++-
 lama_cleaner/tests/mask.jpg                | Bin 0 -> 11048 bytes
 lama_cleaner/tests/test_boxes_from_mask.py | 15 +++++
 main.py                                    | 13 +++-
 6 files changed, 125 insertions(+), 6 deletions(-)
 create mode 100644 lama_cleaner/tests/mask.jpg
 create mode 100644 lama_cleaner/tests/test_boxes_from_mask.py

diff --git a/README.md b/README.md
index a717f58..245ec53 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,10 @@ Install requirements: `pip3 install -r requirements.txt`
 python3 main.py --device=cuda --port=8080 --model=lama
 ```
 
+- `--crop-trigger-size`: If the image size is larger than crop-trigger-size, crop each masked area from the original image to do inference.
+  Mainly for performance and memory reasons on **very** large images. Default is 2042,2042.
+- `--crop-size`: Crop size used when `--crop-trigger-size` is triggered. Default is 512,512.
+
 ### Start server with LDM model
 
 ```bash
@@ -35,7 +39,6 @@ results than LaMa.
 |--------------|------|----|
 |![photo-1583445095369-9c651e7e5d34](https://user-images.githubusercontent.com/3998421/156923525-d6afdec3-7b98-403f-ad20-88ebc6eb8d6d.jpg)|![photo-1583445095369-9c651e7e5d34_cleanup_lama](https://user-images.githubusercontent.com/3998421/156923620-a40cc066-fd4a-4d85-a29f-6458711d1247.png)|![photo-1583445095369-9c651e7e5d34_cleanup_ldm](https://user-images.githubusercontent.com/3998421/156923652-0d06c8c8-33ad-4a42-a717-9c99f3268933.png)|
 
-
 Blogs about diffusion models:
 - https://lilianweng.github.io/posts/2021-07-11-diffusion-models/
 
diff --git a/lama_cleaner/helper.py b/lama_cleaner/helper.py
index 943d1be..9c6e986 100644
--- a/lama_cleaner/helper.py
+++ b/lama_cleaner/helper.py
@@ -1,5 +1,6 @@
 import os
 import sys
+from typing import List
 from urllib.parse import urlparse
 
 import cv2
@@ -80,3 +81,27 @@ def pad_img_to_modulo(img, mod):
         ((0, 0), (0, out_height - height), (0, out_width - width)),
         mode="symmetric",
     )
+
+
+def boxes_from_mask(mask: np.ndarray) -> List[np.ndarray]:
+    """
+    Args:
+        mask: (1, h, w) 0~1
+
+    Returns:
+        List of boxes, each as [left, top, right, bottom], one per masked region.
+    """
+    height, width = mask.shape[1:]
+    _, thresh = cv2.threshold((mask.transpose(1, 2, 0) * 255).astype(np.uint8), 127, 255, 0)
+    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+    boxes = []
+    for cnt in contours:
+        x, y, w, h = cv2.boundingRect(cnt)
+        box = np.array([x, y, x + w, y + h]).astype(int)
+
+        box[::2] = np.clip(box[::2], 0, width)
+        box[1::2] = np.clip(box[1::2], 0, height)
+        boxes.append(box)
+
+    return boxes
diff --git a/lama_cleaner/lama/__init__.py b/lama_cleaner/lama/__init__.py
index 3075e15..49f2596 100644
--- a/lama_cleaner/lama/__init__.py
+++ b/lama_cleaner/lama/__init__.py
@@ -1,10 +1,11 @@
 import os
+from typing import List
 
 import cv2
 import torch
 import numpy as np
 
-from lama_cleaner.helper import pad_img_to_modulo, download_model
+from lama_cleaner.helper import pad_img_to_modulo, download_model, boxes_from_mask
 
 LAMA_MODEL_URL = os.environ.get(
     "LAMA_MODEL_URL",
@@ -13,7 +14,16 @@ LAMA_MODEL_URL = os.environ.get(
 
 
 class LaMa:
-    def __init__(self, device):
+    def __init__(self, crop_trigger_size: List[int], crop_size: List[int], device):
+        """
+
+        Args:
+            crop_trigger_size: h, w
+            crop_size: h, w
+            device:
+        """
+        self.crop_trigger_size = crop_trigger_size
+        self.crop_size = crop_size
         self.device = device
 
         if os.environ.get("LAMA_MODEL"):
@@
-32,6 +42,63 @@ class LaMa: @torch.no_grad() def __call__(self, image, mask): + """ + image: [C, H, W] RGB + mask: [1, H, W] + return: BGR IMAGE + """ + area = image.shape[1] * image.shape[2] + if area < self.crop_trigger_size[0] * self.crop_trigger_size[1]: + return self._run(image, mask) + + print("Trigger crop image") + boxes = boxes_from_mask(mask) + crop_result = [] + for box in boxes: + crop_image, crop_box = self._run_box(image, mask, box) + crop_result.append((crop_image, crop_box)) + + image = (image.transpose(1, 2, 0) * 255).astype(np.uint8)[:, :, ::-1] + for crop_image, crop_box in crop_result: + x1, y1, x2, y2 = crop_box + image[y1:y2, x1:x2, :] = crop_image + return image + + def _run_box(self, image, mask, box): + """ + + Args: + image: [C, H, W] RGB + mask: [1, H, W] + box: [left,top,right,bottom] + + Returns: + BGR IMAGE + """ + box_h = box[3] - box[1] + box_w = box[2] - box[0] + cx = (box[0] + box[2]) // 2 + cy = (box[1] + box[3]) // 2 + crop_h, crop_w = self.crop_size + img_h, img_w = image.shape[1:] + + # TODO: when box_w > crop_w, add some margin around? + w = max(crop_w, box_w) + h = max(crop_h, box_h) + + l = max(cx - w // 2, 0) + t = max(cy - h // 2, 0) + r = min(cx + w // 2, img_w) + b = min(cy + h // 2, img_h) + + crop_img = image[:, t:b, l:r] + crop_mask = mask[:, t:b, l:r] + + print(f"Apply zoom in size width x height: {crop_img.shape}") + + return self._run(crop_img, crop_mask), [l, t, r, b] + + def _run(self, image, mask): """ image: [C, H, W] RGB mask: [1, H, W] @@ -51,5 +118,5 @@ class LaMa: cur_res = inpainted_image[0].permute(1, 2, 0).detach().cpu().numpy() cur_res = cur_res[0:origin_height, 0:origin_width, :] cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8") - cur_res = cv2.cvtColor(cur_res, cv2.COLOR_BGR2RGB) + cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR) return cur_res diff --git a/lama_cleaner/tests/mask.jpg b/lama_cleaner/tests/mask.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a2aec11caf09ebaab0fb7c35d8145cd5d0f7092d GIT binary patch literal 11048 zcmeHLbyQScyFX`=92x|Pp^E@yMF)fdMmnVgL_z^YP_ac2Q4tw~5(7cPKoKMq zix34x9Sj_0WbQs-px?Ld`*Ho_u65TvXU@!T$Fuj-`#HmWz#N3Q?W}CAAQ*-q8*oF+ zAt=oNeGOn7uqcmz>PRUOhZvt5BfRs`29mYs#% zJz{we^P%&Y@if+b6Lhl_<}gpCt+_eL$;Hvi*1>XCD4s*U5fLGPDJV27Cd%5y)YHS$ zi->;=nPs@qP1^ms4r1dRg>rx6&^qd`8Ql6M1OOIhQuO#BbNktIHH4B*@ItIz<-9nuLXja z1AN}k*K!rWdH_$vQQTGnYyv?jnTQxO7l7>m_KyvAH3e81;HrQqD_4Mx03Ha6wL!2p z1fdrt2Ca4nSOj2);Bb3{Rtn&Nu=Oh)0S0@axB5k!BOavyaB~pZ7O|NOz^|ghU647z zUoixKvV|qUfDeowCB}C3cbpX+w+iV8d}4YMgUk^=!CqL74Ze0x0LuX^M-H(@_66%; z9U@|!kTpRY8y6O`64}cH;2Y%VS$RPl#}X5?8u4sRfLF&vtwv}8Ke(&_ij^(ERsdH7 zMcE+zKwcazB4pOjfKKqLzEPIepdaW2A0HL#g6ILrh%Y0Dxgz=mK7xM=xxu~=848Ei zgKs}53>t!nP&7n=;$~4GM zv2Z8=%nbpQKl+J~KUiUo5?Li2JV*;!d2DuNzwbM$f*sApk3+FwJP4Wq<3RwefCguK z;sJ-DpluF%W5HM;!1H&}1fHY`In8oA#{)8s>@xm+mvBJgKfBi)9U|u(Z{~pKFenj> zA-v7XjuFDBW3<4dk8!{lV$^4MivlweAeGsEgW0|xyZ(|Kk!}3vu91M(`O?qV*bgl? 
z0e!`Ql|z7pVc+!=4Y*+dzJtG0|I+5{jypwS6(S;vb~xLvD+l}n$Sjm0y#;iPuW{AC*dhh!j?aBX*RexFY2XDVy)p5Sx%zBdlzxWN<3)_vo zhkc6efr#MuIkp!&2GAqyV{kwG$~{YJl=1%2k{VOf5cn< zRX+=fka_VD5Vlh!B4S;6@-#VO&VGpP$k{8WRzm5@>Sxs_)XUT>)jy~|oz=pu{nY!_ z->CPhmw=Ji=w0X<^cD0)bOX8rB7#-}x(R)87FD7zgZtc1y5A#WPRGAT1md55-y;B# zKIXd*&X1M(Spl&JvJPSmWd08wZulPA^L-W3__f)v4lAOuEj2=W9A!s4I$_+haS(1Umo;(5P(aW3z90xo}e z8X{j1uzu8Rmh=5C@1HUff0Gi)fRy=h26KW)p4(|Ihy2*{hhAqb{96Wf1v^oojew6R zpjRY2gn`^c{A~`ByOFE~JRV*%Yr}uzqqyZbd7LFq9?;HZDRI_CN=qVyFx{1yw=kp*rX?)C9Fbx1oDbFZ2kUK%>xWhz3nT(=ZIsh42Bm1U?0ygD=8O@J+Z2eh3f2 z8K@noeW+4YC8`$Hgu0FDMGc}} zqduV_G&@=lEs0h^>!Hojj%Y7*06Getg3dt~p-a(KAPcmk@1qCN6X-7(9EJ}=#Haw< z*kD{Sewav13MLP80CN&ki)q2!#|&X;m>Dc5Rurp*U5>TIdSC;w8?jl~B5XOf2HOg< z_9%7=hrXWu zxyduei|3W&wdD2Z&EP%8d!2WL56vgdXU^xxw}r2aua$3ftyHagtx0VqZL;$Yg}qPY$9#qZ&GgZ(sYUG z2GdH@w`S^QQD!w}pUsz>Z!*7Tfwr)=$g#L<$!qCqdBF0im6TPW)oCl5wT^X?^)(x; zjlE5QO~0*}t)J~FTbkWcyG?d2_8j)B><`!vJ19CtIW(+5t*~3MYsKT0vMa+@);U6s zwvM|TpE$`oMLJz_#yLAVA9Q}v(5*_pOy*8^8AEI+1lD z>#mTvNWP>CK3E@jp9&v_ucL3N@2B;)>kqA``I-Ce^?U7a;=kK}oNP!gB)<$W3@8j3 z3p5Pe6*wMb98?rEK{2NspiFMC*-*UUbFgFZiC|`kdk8g@CDbSMVi<2&a9B&YczAqx zSAWZB^2j}rv?%+i@@P1EZFGGMKX@tah*gNqj2((IiaQub-{`jSLOf4=SbRr< zV#2nBmx)%1$CEHgeo4*AlF6Hshc=mRDoKG;)~7V5N~dO|j;2|so!rc}Ie7D(boKPY z^p9KIw=`skWu#?{WLjsQ&LU(*W%Xz4XBTh9Y^7|yyG?7`{%tck{yDdD)pLt-8F_wr zx3+6+-?yE)BVfnfeBJybJ8?V1cJ>t*7o05QDNHOJ*|lOLJlX+YWs`?0>lXh{=(vVzJ`f;;%<3M<0|} zm(-TZl@=Yt9g92mqRg$V{kY!oQ{^J%x#izZgq;{X>2k8=)Y4O@Dnu*tPoqx9o*u9C zuDo}~;!J&&N>vF}keYiIIvacT)j6MYedq1ZH($`ZP*p8keW-@7CZ`szO{k^S1=S7K zd)9Ydw7b~cpx;n)N%hi+%aWH5UJsl4Z|CkZ!W)i@s{4Ly4$+9Ydds0YVPRVsky6rx3+U>XZ^ip z_Zqqkx~_H`cQ@U)xZl=e+jF~jWp7uXYu|%@@BYCD{tw0;hCG}ch#Q!Glm-g>oG08* zik^x-Eq%7+8Ff&1@XC*O${98FV`KVbEAC@^=ut35-J+9xSh|B$8OBYQ>YkgEYE1RIuB4H6(IeCR8irPB5OZAp5H@C2~vbM3cTeaHN%^egfe*WZu zz#z(o=$P2JjqwSITQV}UvbS!_DcrStPto3e`wx^JD?46(;^e7{^B1aXYU}DRHZ)&v zX>Du2ar0L9{hr>w{s#{ShK5I;kG^<0HvVez{fCdArasfZd}TfYFTki-p^$GBa$(TO zHxL1h!C}FPjXZc(7VzM4IB*WLvLP2h9PDfyNQ3zV{6l8X9DrC5)fhuFsbi%%3jdQ& zBlF(h5O@6zarfU4-~Supp1&dP{Tt%Gzaj4b8{+@B^FKuhHRanD5B_BY$#9*pWs%q%^Ol24&c*s6gAijJ?Ki0EGW#t!(^LLaKu(wXx>@tp1IOd#o)O}#1VBhBm zothpcnXuz)#Y{O^nlR!=cWA8B8sehMb20cnUo(=Zr=3V|T(6cfu_%1%q}WQy*;XS4<2abFdiIJqe})hkshvUP7Hr@Wm+O$EZJ~qleI^#^(1LaOnv%8 zicmC8w?FPwpMqRL^4_%MQ_ANWLOG@ChMw>{Th4t_nF|Ut#Z9VX9hsd8Y)oj&!O+27Gc5Zj%U{^=l+|Wi92qcc)OM+=S?}7zgxJydc37rL4CGX7;44*d z2p+e3G1v4*{stjRl6@IHHwy?BYKo}m)`$z-_*anpCPdK`iykWjKY0un)|k}UH~Y zjpR0Z=i7uSDo4~h=8uHR2)X#K660gtUxVcR2Nx39iBDu9<=9UCL z)GNNFbL5I^x8_KK^D%E6U`^K4P}BM(^~H}^R48p+SlUyd_N&pglFaVwyrV0+mB>D+ zPD#n8rLOH&AqSy#zo=DTGJ_3D*g8KDO07gu(cdhcNf{EfB<+Y|u9VKGefQ({ww8FD zd<<7t-{|y?FqRxAbqdTg-vYB)|2w)C>zWzyy~E`L5k}Cf!9%G%;kMNAuQ@ZR#)DI# zDW5KrX*LEfApV`1k**M&Yp>MG`AKBcgC42a}lSU`;%=X)gOei4=?=@4=gz*>W6kt3D z?l%5;AZezzku${aa>iq8T2{2Ol;zWsnv=DKT004SZHaI6#)4@teik5&3rs;(#|EFZ zsclV#Q=};p0#^m-Rnr#zA64ww-^jb(58&yfMDk_ZO0Q?(P`4(gSKcyp^GPisJ$O08 zr9@wLhAwS+FEym`;r*7lr;>cvgDc{94ZR+C;`w+9o5*CP++kH8VxiX}1_Qf(#|Y{Q z6B^-Sh?RiMb(N-eeVPgJF!bu^nImEW{HU8kq~rHh&N=EFt_syr5d;m%PX- zjm(8b*&jPE?C}1BcwvX=AH)kg+5RA2*opNA@&DWTXBPTTz(OhP4To1ZkZyB*rfLtOhnLcTKfmH%sb>y zm~-@Nsa(12@ODCg$I6~J0-|<>)8&jAi-V^lQ=2RqV&Ri1ZJn01rD^Y^xlY7iYIC`B z!ByfienPjF30uOp58R#7SfGWhTyXq){;%&NNV|3%P^OtVC2gVmYuA#t=IqD}eV(~vWoq)F=mU(x1v$#k^uDC)#D1yvLt#Ez zq`ngjk;8*(T=9-GtV>VC@3lUcrBuwvu$UB^Debye;W>OYSfNRJ?<3AiJJ#nAe?)1c z?a9*1;ExS=pb9kI8W^GX+wPjq0sG_}hoxut(HzwO=P%U(K%U*88@UslldrISifv{-n)j~`$=)V X?{|F;92Nal9{(qS+xWj&?mhf3I)(uT literal 0 HcmV?d00001 diff --git a/lama_cleaner/tests/test_boxes_from_mask.py b/lama_cleaner/tests/test_boxes_from_mask.py new file 
mode 100644
index 0000000..3faa4c6
--- /dev/null
+++ b/lama_cleaner/tests/test_boxes_from_mask.py
@@ -0,0 +1,15 @@
+import cv2
+import numpy as np
+
+from lama_cleaner.helper import boxes_from_mask
+
+
+def test_boxes_from_mask():
+    mask = cv2.imread("mask.jpg", cv2.IMREAD_GRAYSCALE)
+    mask = mask[:, :, np.newaxis]
+    mask = (mask / 255).transpose(2, 0, 1)
+    boxes = boxes_from_mask(mask)
+    print(boxes)
+
+
+test_boxes_from_mask()
diff --git a/main.py b/main.py
index 18968a1..a50df20 100644
--- a/main.py
+++ b/main.py
@@ -97,12 +97,18 @@ def get_args_parser():
     parser = argparse.ArgumentParser()
     parser.add_argument("--port", default=8080, type=int)
     parser.add_argument("--model", default="lama", choices=["lama", "ldm"])
+    parser.add_argument("--crop-trigger-size", default="2042,2042",
+                        help="If the image size is larger than crop-trigger-size, "
+                             "crop each masked area from the original image to do inference. "
+                             "Mainly for performance and memory reasons. "
+                             "Only used by the lama model.")
+    parser.add_argument("--crop-size", default="512,512", help="Crop size used when --crop-trigger-size is triggered.")
     parser.add_argument(
         "--ldm-steps",
         default=50,
         type=int,
         help="Steps for DDIM sampling process."
-        "The larger the value, the better the result, but it will be more time-consuming",
+             "The larger the value, the better the result, but it will be more time-consuming",
     )
     parser.add_argument("--device", default="cuda", type=str)
    parser.add_argument("--debug", action="store_true")
@@ -115,8 +121,11 @@ def main():
     args = get_args_parser()
     device = torch.device(args.device)
 
+    crop_trigger_size = [int(it) for it in args.crop_trigger_size.split(",")]
+    crop_size = [int(it) for it in args.crop_size.split(",")]
+
     if args.model == "lama":
-        model = LaMa(device)
+        model = LaMa(crop_trigger_size=crop_trigger_size, crop_size=crop_size, device=device)
     elif args.model == "ldm":
         model = LDM(device, steps=args.ldm_steps)
     else:
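
Note (not part of the patch): a usage sketch of the two new flags, spelling out the defaults wired up in `main.py` above. Crop mode only applies to the `lama` model and only kicks in when the input image is larger than `--crop-trigger-size`.

```bash
# Defaults shown explicitly; omitting the flags gives the same behaviour.
python3 main.py --device=cuda --port=8080 --model=lama \
    --crop-trigger-size=2042,2042 --crop-size=512,512
```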
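
Note (not part of the patch): a minimal standalone sketch of the crop-window arithmetic in `LaMa._run_box`, to make the clamping behaviour concrete. The function name `crop_window` is invented for illustration; the math mirrors the patched code.

```python
def crop_window(box, crop_size, img_h, img_w):
    """Mirror of the window computation in LaMa._run_box (illustration only)."""
    left, top, right, bottom = box          # box as returned by boxes_from_mask
    box_h, box_w = bottom - top, right - left
    cx, cy = (left + right) // 2, (top + bottom) // 2
    crop_h, crop_w = crop_size

    # The window never shrinks below crop_size and grows to cover a larger box.
    w, h = max(crop_w, box_w), max(crop_h, box_h)

    # Center the window on the box, then clip it to the image bounds.
    l = max(cx - w // 2, 0)
    t = max(cy - h // 2, 0)
    r = min(cx + w // 2, img_w)
    b = min(cy + h // 2, img_h)
    return l, t, r, b


# An 80x100 (w x h) masked region centered at (300, 400) in a 3000x2000 image
# gets a full 512x512 window; near the image border the window is clipped.
print(crop_window([260, 350, 340, 450], (512, 512), img_h=2000, img_w=3000))  # (44, 144, 556, 656)
print(crop_window([10, 10, 90, 110], (512, 512), img_h=2000, img_w=3000))     # (0, 0, 306, 316)
```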