diff --git a/lama_cleaner/model/controlnet.py b/lama_cleaner/model/controlnet.py
index b9c937d..4eece59 100644
--- a/lama_cleaner/model/controlnet.py
+++ b/lama_cleaner/model/controlnet.py
@@ -42,23 +42,43 @@ NAMES_MAP = {
     "realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting",
 }
 
+NATIVE_NAMES_MAP = {
+    "sd1.5": "runwayml/stable-diffusion-v1-5",
+    "anything4": "andite/anything-v4.0",
+    "realisticVision1.4": "SG161222/Realistic_Vision_V1.4",
+}
 
-def load_from_local_model(local_model_path, torch_dtype, controlnet):
+
+def make_inpaint_condition(image, image_mask):
+    """
+    image: [H, W, C] RGB
+    image_mask: [H, W, 1] 255 means area to repaint
+    """
+    image = image.astype(np.float32) / 255.0
+    image[image_mask[:, :, -1] > 128] = -1.0  # set as masked pixel
+    image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
+    image = torch.from_numpy(image)
+    return image
+
+
+def load_from_local_model(
+    local_model_path, torch_dtype, controlnet, pipe_class, is_native_control_inpaint
+):
     from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
-        load_pipeline_from_original_stable_diffusion_ckpt,
+        download_from_original_stable_diffusion_ckpt,
     )
-    from .pipeline import StableDiffusionControlNetInpaintPipeline
 
     logger.info(f"Converting {local_model_path} to diffusers controlnet pipeline")
-    pipe = load_pipeline_from_original_stable_diffusion_ckpt(
+    pipe = download_from_original_stable_diffusion_ckpt(
         local_model_path,
-        num_in_channels=9,
+        num_in_channels=4 if is_native_control_inpaint else 9,
         from_safetensors=local_model_path.endswith("safetensors"),
         device="cpu",
+        load_safety_checker=False
     )
 
-    inpaint_pipe = StableDiffusionControlNetInpaintPipeline(
+    inpaint_pipe = pipe_class(
         vae=pipe.vae,
         text_encoder=pipe.text_encoder,
         tokenizer=pipe.tokenizer,
@@ -81,9 +101,6 @@ class ControlNet(DiffusionInpaintModel):
     min_size = 512
 
     def init_model(self, device: torch.device, **kwargs):
-        from .pipeline import StableDiffusionControlNetInpaintPipeline
-
-        model_id = NAMES_MAP[kwargs["name"]]
         fp16 = not kwargs.get("no_half", False)
 
         model_kwargs = {
@@ -102,17 +119,35 @@ class ControlNet(DiffusionInpaintModel):
         use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
         torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
 
+        sd_controlnet_method = kwargs["sd_controlnet_method"]
+
+        if sd_controlnet_method == "control_v11p_sd15_inpaint":
+            from diffusers import StableDiffusionControlNetPipeline as PipeClass
+
+            self.is_native_control_inpaint = True
+        else:
+            from .pipeline import StableDiffusionControlNetInpaintPipeline as PipeClass
+
+            self.is_native_control_inpaint = False
+
+        if self.is_native_control_inpaint:
+            model_id = NATIVE_NAMES_MAP[kwargs["name"]]
+        else:
+            model_id = NAMES_MAP[kwargs["name"]]
+
         controlnet = ControlNetModel.from_pretrained(
-            f"lllyasviel/sd-controlnet-canny", torch_dtype=torch_dtype
+            f"lllyasviel/{sd_controlnet_method}", torch_dtype=torch_dtype
         )
         if kwargs.get("sd_local_model_path", None):
             self.model = load_from_local_model(
                 kwargs["sd_local_model_path"],
                 torch_dtype=torch_dtype,
                 controlnet=controlnet,
+                pipe_class=PipeClass,
+                is_native_control_inpaint=self.is_native_control_inpaint,
             )
         else:
-            self.model = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+            self.model = PipeClass.from_pretrained(
                 model_id,
                 controlnet=controlnet,
                 revision="fp16" if use_gpu and fp16 else "main",
@@ -156,28 +191,45 @@ class ControlNet(DiffusionInpaintModel):
 
         img_h, img_w = image.shape[:2]
 
-        canny_image = cv2.Canny(image, 100, 200)
-        canny_image = canny_image[:, :, None]
-        canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
-        canny_image = PIL.Image.fromarray(canny_image)
-        mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
-        image = PIL.Image.fromarray(image)
+        if self.is_native_control_inpaint:
+            control_image = make_inpaint_condition(image, mask)
+            output = self.model(
+                prompt=config.prompt,
+                image=control_image,
+                height=img_h,
+                width=img_w,
+                num_inference_steps=config.sd_steps,
+                guidance_scale=config.sd_guidance_scale,
+                controlnet_conditioning_scale=config.controlnet_conditioning_scale,
+                negative_prompt=config.negative_prompt,
+                generator=torch.manual_seed(config.sd_seed),
+                output_type="np.array",
+            ).images[0]
+        else:
+            canny_image = cv2.Canny(image, 100, 200)
+            canny_image = canny_image[:, :, None]
+            canny_image = np.concatenate(
+                [canny_image, canny_image, canny_image], axis=2
+            )
+            canny_image = PIL.Image.fromarray(canny_image)
+            mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
+            image = PIL.Image.fromarray(image)
 
-        output = self.model(
-            image=image,
-            control_image=canny_image,
-            prompt=config.prompt,
-            negative_prompt=config.negative_prompt,
-            mask_image=mask_image,
-            num_inference_steps=config.sd_steps,
-            guidance_scale=config.sd_guidance_scale,
-            output_type="np.array",
-            callback=self.callback,
-            height=img_h,
-            width=img_w,
-            generator=torch.manual_seed(config.sd_seed),
-            controlnet_conditioning_scale=config.controlnet_conditioning_scale,
-        ).images[0]
+            output = self.model(
+                image=image,
+                control_image=canny_image,
+                prompt=config.prompt,
+                negative_prompt=config.negative_prompt,
+                mask_image=mask_image,
+                num_inference_steps=config.sd_steps,
+                guidance_scale=config.sd_guidance_scale,
+                output_type="np.array",
+                callback=self.callback,
+                height=img_h,
+                width=img_w,
+                generator=torch.manual_seed(config.sd_seed),
+                controlnet_conditioning_scale=config.controlnet_conditioning_scale,
+            ).images[0]
 
         output = (output * 255).round().astype("uint8")
         output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
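
Editor's note, not part of the patch: make_inpaint_condition is what feeds the native control_v11p_sd15_inpaint ControlNet. It scales pixels to [0, 1] and overwrites every masked pixel with the sentinel value -1.0, so the network can distinguish "repaint here" from real image content. A minimal standalone sketch of that transformation (shapes and values are illustrative):

    import numpy as np
    import torch

    image = np.zeros((64, 64, 3), dtype=np.uint8)  # RGB input
    mask = np.zeros((64, 64, 1), dtype=np.uint8)
    mask[16:48, 16:48] = 255  # 255 marks the area to repaint

    condition = image.astype(np.float32) / 255.0
    condition[mask[:, :, -1] > 128] = -1.0  # sentinel for masked pixels
    condition = torch.from_numpy(np.expand_dims(condition, 0).transpose(0, 3, 1, 2))
    assert condition.shape == (1, 3, 64, 64)  # [N, C, H, W], as the pipeline expects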
diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py
index 5af1cb2..f1d7ba1 100644
--- a/lama_cleaner/parse_args.py
+++ b/lama_cleaner/parse_args.py
@@ -38,6 +38,14 @@ def parse_args():
         "--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP
     )
     parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP)
+    parser.add_argument(
+        "--sd-controlnet-method",
+        default="control_v11p_sd15_inpaint",
+        choices=[
+            "control_v11p_sd15_canny",
+            "control_v11p_sd15_inpaint",
+        ],
+    )
     parser.add_argument("--sd-local-model-path", default=None, help=SD_LOCAL_MODEL_HELP)
     parser.add_argument(
         "--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP
@@ -86,7 +94,7 @@ def parse_args():
         "--interactive-seg-model",
         default="vit_l",
         choices=AVAILABLE_INTERACTIVE_SEG_MODELS,
-        help=INTERACTIVE_SEG_MODEL_HELP
+        help=INTERACTIVE_SEG_MODEL_HELP,
     )
     parser.add_argument(
         "--interactive-seg-device",
@@ -168,11 +176,11 @@ def parse_args():
     if args.config_installer:
         if args.installer_config is None:
             parser.error(
-                f"args.config_installer==True, must set args.installer_config to store config file"
+                "args.config_installer==True, must set args.installer_config to store config file"
             )
         from lama_cleaner.web_config import main
 
-        logger.info(f"Launching installer web config page")
+        logger.info("Launching installer web config page")
         main(args.installer_config)
         exit()
 
@@ -194,10 +202,6 @@ def parse_args():
                 "torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation"
             )
 
-    if args.sd_controlnet:
-        if args.model not in SD15_MODELS:
-            logger.warning(f"--sd_controlnet only support {SD15_MODELS}")
-
     if args.sd_local_model_path and args.model == "sd1.5":
         if not os.path.exists(args.sd_local_model_path):
             parser.error(
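
Editor's note, not part of the patch: the new --sd-controlnet-method flag defaults to control_v11p_sd15_inpaint, so an existing --sd-controlnet setup moves from canny to native inpaint conditioning unless the canny method is requested explicitly. A standalone argparse sketch of just this flag (the real parser defines many more options):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--sd-controlnet", action="store_true")
    parser.add_argument(
        "--sd-controlnet-method",
        default="control_v11p_sd15_inpaint",
        choices=["control_v11p_sd15_canny", "control_v11p_sd15_inpaint"],
    )

    args = parser.parse_args(["--sd-controlnet"])
    assert args.sd_controlnet_method == "control_v11p_sd15_inpaint"  # inpaint is the default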
diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py
index 8e93424..a36bb91 100644
--- a/lama_cleaner/server.py
+++ b/lama_cleaner/server.py
@@ -537,6 +537,7 @@ def main(args):
     model = ModelManager(
         name=args.model,
         sd_controlnet=args.sd_controlnet,
+        sd_controlnet_method=args.sd_controlnet_method,
         device=device,
         no_half=args.no_half,
         hf_access_token=args.hf_access_token,
diff --git a/requirements.txt b/requirements.txt
index deb8aa0..718bc44 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,7 +7,7 @@ pydantic
 rich
 loguru
 yacs
-diffusers[torch]==0.14.0
+diffusers==0.16.1
 transformers==4.27.4
 gradio
 piexif==1.1.3
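
Editor's note, not part of the patch: the diffusers pin moves from 0.14.0 to 0.16.1 in step with the import change in controlnet.py, where load_pipeline_from_original_stable_diffusion_ckpt is replaced by the renamed upstream helper download_from_original_stable_diffusion_ckpt. A one-line sanity check after upgrading:

    # Succeeds on diffusers==0.16.1; this helper does not exist under
    # this name on the old 0.14.0 pin.
    from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
        download_from_original_stable_diffusion_ckpt,
    )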