From 87f54bb87e9607e0132508f5c3b30f72270ee9ff Mon Sep 17 00:00:00 2001
From: Qing
Date: Thu, 11 May 2023 21:51:58 +0800
Subject: [PATCH] wip: controlnet

---
 lama_cleaner/const.py                       |  7 ++-
 lama_cleaner/model/controlnet.py            | 35 ++++++++----
 lama_cleaner/model/sd.py                    |  4 +-
 lama_cleaner/model_manager.py               | 17 +++++-
 lama_cleaner/parse_args.py                  |  5 +-
 lama_cleaner/server.py                      | 11 ++++
 lama_cleaner/tests/mask.png                 | Bin 0 -> 7916 bytes
 lama_cleaner/tests/test_controlnet.py       | 60 ++++++++++++++++++--
 lama_cleaner/tests/test_instruct_pix2pix.py |  4 +-
 requirements.txt                            |  3 +-
 10 files changed, 117 insertions(+), 29 deletions(-)
 create mode 100644 lama_cleaner/tests/mask.png

diff --git a/lama_cleaner/const.py b/lama_cleaner/const.py
index 11148df..1b2b7e7 100644
--- a/lama_cleaner/const.py
+++ b/lama_cleaner/const.py
@@ -53,8 +53,13 @@ Run Stable Diffusion text encoder model on CPU to save GPU memory.
 """
 
 SD_CONTROLNET_HELP = """
-Run Stable Diffusion 1.5 inpainting model with Canny ControlNet control.
+Run Stable Diffusion inpainting model with ControlNet. You can switch the control method in the webui.
 """
+SD_CONTROLNET_CHOICES = [
+    "control_v11p_sd15_canny",
+    "control_v11p_sd15_openpose",
+    "control_v11p_sd15_inpaint",
+]
 
 SD_LOCAL_MODEL_HELP = """
 Load Stable Diffusion 1.5 model(ckpt/safetensors) from local path.
diff --git a/lama_cleaner/model/controlnet.py b/lama_cleaner/model/controlnet.py
index 4eece59..cb0a963 100644
--- a/lama_cleaner/model/controlnet.py
+++ b/lama_cleaner/model/controlnet.py
@@ -4,9 +4,7 @@ import PIL.Image
 import cv2
 import numpy as np
 import torch
-from diffusers import (
-    ControlNetModel,
-)
+from diffusers import ControlNetModel
 from loguru import logger
 
 from lama_cleaner.model.base import DiffusionInpaintModel
@@ -75,7 +73,7 @@ def load_from_local_model(
         num_in_channels=4 if is_native_control_inpaint else 9,
         from_safetensors=local_model_path.endswith("safetensors"),
         device="cpu",
-        load_safety_checker=False
+        load_safety_checker=False,
     )
 
     inpaint_pipe = pipe_class(
@@ -92,7 +90,7 @@ def load_from_local_model(
 
     del pipe
     gc.collect()
-    return inpaint_pipe.to(torch_dtype)
+    return inpaint_pipe.to(torch_dtype=torch_dtype)
 
 
 class ControlNet(DiffusionInpaintModel):
@@ -120,6 +118,7 @@ class ControlNet(DiffusionInpaintModel):
         torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
 
         sd_controlnet_method = kwargs["sd_controlnet_method"]
+        self.sd_controlnet_method = sd_controlnet_method
         if sd_controlnet_method == "control_v11p_sd15_inpaint":
             from diffusers import StableDiffusionControlNetPipeline as PipeClass
 
@@ -206,18 +205,30 @@ class ControlNet(DiffusionInpaintModel):
                 output_type="np.array",
             ).images[0]
         else:
-            canny_image = cv2.Canny(image, 100, 200)
-            canny_image = canny_image[:, :, None]
-            canny_image = np.concatenate(
-                [canny_image, canny_image, canny_image], axis=2
-            )
-            canny_image = PIL.Image.fromarray(canny_image)
+            if "canny" in self.sd_controlnet_method:
+                canny_image = cv2.Canny(image, 100, 200)
+                canny_image = canny_image[:, :, None]
+                canny_image = np.concatenate(
+                    [canny_image, canny_image, canny_image], axis=2
+                )
+                canny_image = PIL.Image.fromarray(canny_image)
+                control_image = canny_image
+            elif "openpose" in self.sd_controlnet_method:
+                from controlnet_aux import OpenposeDetector
+
+                processor = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
+                control_image = processor(image, hand_and_face=True)
+            else:
+                raise NotImplementedError(
+                    f"{self.sd_controlnet_method} not implemented"
+                )
+
             mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
             image = PIL.Image.fromarray(image)
 
             output = self.model(
                 image=image,
-                control_image=canny_image,
+                control_image=control_image,
                 prompt=config.prompt,
                 negative_prompt=config.negative_prompt,
                 mask_image=mask_image,
diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py
index e1ecece..fec7fc6 100644
--- a/lama_cleaner/model/sd.py
+++ b/lama_cleaner/model/sd.py
@@ -35,13 +35,13 @@ class CPUTextEncoderWrapper:
 
 def load_from_local_model(local_model_path, torch_dtype, disable_nsfw=True):
     from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
-        load_pipeline_from_original_stable_diffusion_ckpt,
+        download_from_original_stable_diffusion_ckpt,
     )
     from diffusers.pipelines.stable_diffusion import StableDiffusionInpaintPipeline
 
     logger.info(f"Converting {local_model_path} to diffusers pipeline")
 
-    pipe = load_pipeline_from_original_stable_diffusion_ckpt(
+    pipe = download_from_original_stable_diffusion_ckpt(
         local_model_path,
         num_in_channels=9,
         from_safetensors=local_model_path.endswith("safetensors"),
diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py
index ac9abe3..74330bb 100644
--- a/lama_cleaner/model_manager.py
+++ b/lama_cleaner/model_manager.py
@@ -12,6 +12,7 @@ from lama_cleaner.model.mat import MAT
 from lama_cleaner.model.paint_by_example import PaintByExample
 from lama_cleaner.model.instruct_pix2pix import InstructPix2Pix
 from lama_cleaner.model.sd import SD15, SD2, Anything4, RealisticVision14
+from lama_cleaner.model.utils import torch_gc
 from lama_cleaner.model.zits import ZITS
 from lama_cleaner.model.opencv2 import OpenCV2
 from lama_cleaner.schema import Config
@@ -59,7 +60,7 @@ class ModelManager:
     def __call__(self, image, mask, config: Config):
         return self.model(image, mask, config)
 
-    def switch(self, new_name: str):
+    def switch(self, new_name: str, **kwargs):
         if new_name == self.name:
             return
         try:
@@ -75,3 +76,17 @@ class ModelManager:
             self.name = new_name
         except NotImplementedError as e:
             raise e
+
+    def switch_controlnet_method(self, control_method: str):
+        if not self.kwargs.get("sd_controlnet"):
+            return
+        if self.kwargs["sd_controlnet_method"] == control_method:
+            return
+
+        del self.model
+        torch_gc()
+
+        self.kwargs["sd_controlnet_method"] = control_method
+        self.model = self.init_model(
+            self.name, switch_mps_device(self.name, self.device), **self.kwargs
+        )
diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py
index 1d5abfe..1a3f48c 100644
--- a/lama_cleaner/parse_args.py
+++ b/lama_cleaner/parse_args.py
@@ -41,10 +41,7 @@ def parse_args():
     parser.add_argument(
         "--sd-controlnet-method",
         default="control_v11p_sd15_inpaint",
-        choices=[
-            "control_v11p_sd15_canny",
-            "control_v11p_sd15_inpaint",
-        ],
+        choices=SD_CONTROLNET_CHOICES,
     )
     parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP)
     parser.add_argument(
diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py
index c7ef623..8d1984c 100644
--- a/lama_cleaner/server.py
+++ b/lama_cleaner/server.py
@@ -436,6 +436,17 @@ def switch_model():
     return f"ok, switch to {new_name}", 200
 
 
+@app.route("/controlnet_method", methods=["POST"])
+def switch_controlnet_method():
+    new_method = request.form.get("method")
+
+    try:
+        model.switch_controlnet_method(new_method)
+    except NotImplementedError:
+        return f"Failed to switch to {new_method}: not implemented", 500
+    return f"Switched to {new_method}", 200
+
+
 @app.route("/")
 def index():
     return send_file(os.path.join(BUILD_DIR, "index.html"))
diff --git a/lama_cleaner/tests/mask.png b/lama_cleaner/tests/mask.png
new file mode 100644
index 0000000000000000000000000000000000000000..29cf20b5282d2a5684d75ef667dde6d8415f27e9
GIT binary patch
literal 7916
[base85-encoded payload omitted: binary test asset lama_cleaner/tests/mask.png, 7916 bytes]

literal 0
HcmV?d00001
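
Note: the binary payload above is not human-readable, so the exact contents of the new test asset
lama_cleaner/tests/mask.png cannot be read from the patch text. If a stand-in mask is needed locally, a
minimal sketch like the following can generate one; the 512x512 size and the masked region are assumptions,
not the original asset.

    # Sketch: create a simple single-channel test mask (assumed shape and region).
    import numpy as np
    from PIL import Image

    mask = np.zeros((512, 512), dtype=np.uint8)   # black background = keep
    mask[200:350, 150:400] = 255                  # white rectangle = area to inpaint
    Image.fromarray(mask, mode="L").save("lama_cleaner/tests/mask.png")
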
diff --git a/lama_cleaner/tests/test_controlnet.py b/lama_cleaner/tests/test_controlnet.py
index 1e58875..d699e1b 100644
--- a/lama_cleaner/tests/test_controlnet.py
+++ b/lama_cleaner/tests/test_controlnet.py
@@ -1,5 +1,7 @@
 import os
 
+from lama_cleaner.const import SD_CONTROLNET_CHOICES
+
 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 
 from pathlib import Path
@@ -22,7 +24,10 @@ device = torch.device(device)
 @pytest.mark.parametrize("sampler", [SDSampler.uni_pc])
 @pytest.mark.parametrize("cpu_textencoder", [True])
 @pytest.mark.parametrize("disable_nsfw", [True])
-def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_nsfw):
+@pytest.mark.parametrize("sd_controlnet_method", SD_CONTROLNET_CHOICES)
+def test_runway_sd_1_5(
+    sd_device, strategy, sampler, cpu_textencoder, disable_nsfw, sd_controlnet_method
+):
     if sd_device == "cuda" and not torch.cuda.is_available():
         return
     if device == "mps" and not torch.backends.mps.is_available():
         return
@@ -34,19 +39,20 @@ def test_runway_sd_1_5(sd_device, strategy, sampler, cpu_textencoder, disable_ns
         sd_controlnet=True,
         device=torch.device(sd_device),
         hf_access_token="",
-        sd_run_local=True,
+        sd_run_local=False,
         disable_nsfw=disable_nsfw,
         sd_cpu_textencoder=cpu_textencoder,
+        sd_controlnet_method=sd_controlnet_method,
     )
     cfg = get_config(strategy, prompt="a fox sitting on a bench", sd_steps=sd_steps)
     cfg.sd_sampler = sampler
 
-    name = f"device_{sd_device}_{sampler}_cpu_textencoder_{cpu_textencoder}_disnsfw_{disable_nsfw}"
+    name = f"device_{sd_device}_{sampler}_cpu_textencoder_disable_nsfw"
 
     assert_equal(
         model,
         cfg,
-        f"sd_controlnet_{name}.png",
+        f"sd_controlnet_{sd_controlnet_method}_{name}.png",
         img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
         mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
         fx=1.2,
@@ -68,11 +74,12 @@ def test_local_file_path(sd_device, sampler):
         sd_controlnet=True,
         device=torch.device(sd_device),
         hf_access_token="",
-        sd_run_local=True,
+        sd_run_local=False,
         disable_nsfw=True,
         sd_cpu_textencoder=False,
         cpu_offload=True,
         sd_local_model_path="/Users/cwq/data/models/sd-v1-5-inpainting.ckpt",
+        sd_controlnet_method="control_v11p_sd15_canny",
     )
     cfg = get_config(
         HDStrategy.ORIGINAL,
@@ -86,7 +93,48 @@ def test_local_file_path(sd_device, sampler):
     assert_equal(
         model,
         cfg,
-        f"sd_controlnet_local_model_{name}.png",
+        f"sd_controlnet_canny_local_model_{name}.png",
         img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
         mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
     )
+
+
+@pytest.mark.parametrize("sd_device", ["cuda", "mps"])
+@pytest.mark.parametrize("sampler", [SDSampler.uni_pc])
+def test_local_file_path_controlnet_native_inpainting(sd_device, sampler):
+    if sd_device == "cuda" and not torch.cuda.is_available():
+        return
+    if device == "mps" and not torch.backends.mps.is_available():
+        return
+
+    sd_steps = 1 if sd_device == "cpu" else 30
+    model = ModelManager(
+        name="sd1.5",
+        sd_controlnet=True,
+        device=torch.device(sd_device),
+        hf_access_token="",
+        sd_run_local=False,
+        disable_nsfw=True,
+        sd_cpu_textencoder=False,
+        cpu_offload=True,
+        sd_local_model_path="/Users/cwq/data/models/v1-5-pruned-emaonly.safetensors",
+        sd_controlnet_method="control_v11p_sd15_inpaint",
+    )
+    cfg = get_config(
+        HDStrategy.ORIGINAL,
+        prompt="a fox sitting on a bench",
+        sd_steps=sd_steps,
+        controlnet_conditioning_scale=1.0,
+        sd_strength=1.0
+    )
+    cfg.sd_sampler = sampler
+
+    name = f"device_{sd_device}_{sampler}"
+
+    assert_equal(
+        model,
+        cfg,
+        f"sd_controlnet_local_native_{name}.png",
+        img_p=current_dir / "overture-creations-5sI6fQgYIuo.png",
+        mask_p=current_dir / "overture-creations-5sI6fQgYIuo_mask.png",
+    )
diff --git a/lama_cleaner/tests/test_instruct_pix2pix.py b/lama_cleaner/tests/test_instruct_pix2pix.py
index 9780cf3..9778160 100644
--- a/lama_cleaner/tests/test_instruct_pix2pix.py
+++ b/lama_cleaner/tests/test_instruct_pix2pix.py
@@ -20,7 +20,7 @@ def test_instruct_pix2pix(disable_nsfw, cpu_offload):
     model = ModelManager(name="instruct_pix2pix",
                          device=torch.device(device),
                          hf_access_token="",
-                         sd_run_local=True,
+                         sd_run_local=False,
                          disable_nsfw=disable_nsfw,
                          sd_cpu_textencoder=False,
                          cpu_offload=cpu_offload)
@@ -45,7 +45,7 @@ def test_instruct_pix2pix_snow(disable_nsfw, cpu_offload):
     model = ModelManager(name="instruct_pix2pix",
                          device=torch.device(device),
                          hf_access_token="",
-                         sd_run_local=True,
+                         sd_run_local=False,
                          disable_nsfw=disable_nsfw,
                          sd_cpu_textencoder=False,
                          cpu_offload=cpu_offload)
diff --git a/requirements.txt b/requirements.txt
index 8e4a89a..d64b969 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,4 +13,5 @@ transformers==4.27.4
 gradio
 piexif==1.1.3
 safetensors
-omegaconf
\ No newline at end of file
+omegaconf
+controlnet-aux==0.0.3
\ No newline at end of file
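
Note: a minimal sketch of how the new /controlnet_method endpoint added to lama_cleaner/server.py can be
exercised once the server is running with --sd-controlnet. The host/port and the use of the requests package
are assumptions, not part of this patch; the method names come from SD_CONTROLNET_CHOICES in
lama_cleaner/const.py.

    # Sketch: switch the active ControlNet method over HTTP.
    import requests

    BASE_URL = "http://127.0.0.1:8080"  # assumption: adjust to your --host/--port

    resp = requests.post(
        f"{BASE_URL}/controlnet_method",
        data={"method": "control_v11p_sd15_canny"},  # any entry of SD_CONTROLNET_CHOICES
    )
    print(resp.status_code, resp.text)  # 200 on success, 500 if the method is not implemented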