diff --git a/lama_cleaner/__init__.py b/lama_cleaner/__init__.py index 399fff8..a1ce1a1 100644 --- a/lama_cleaner/__init__.py +++ b/lama_cleaner/__init__.py @@ -1,11 +1,18 @@ +import os + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + import warnings + warnings.simplefilter("ignore", UserWarning) from lama_cleaner.parse_args import parse_args + def entry_point(): args = parse_args() # To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18 from lama_cleaner.server import main + main(args) diff --git a/lama_cleaner/app/src/App.tsx b/lama_cleaner/app/src/App.tsx index 1c2e4c6..f23ff29 100644 --- a/lama_cleaner/app/src/App.tsx +++ b/lama_cleaner/app/src/App.tsx @@ -7,6 +7,7 @@ import Workspace from './components/Workspace' import { enableFileManagerState, fileState, + isControlNetState, isDisableModelSwitchState, isEnableAutoSavingState, toastState, @@ -17,6 +18,7 @@ import useHotKey from './hooks/useHotkey' import { getEnableAutoSaving, getEnableFileManager, + getIsControlNet, getIsDisableModelSwitch, isDesktop, } from './adapters/inpainting' @@ -37,6 +39,7 @@ function App() { const setIsDisableModelSwitch = useSetRecoilState(isDisableModelSwitchState) const setEnableFileManager = useSetRecoilState(enableFileManagerState) const setIsEnableAutoSavingState = useSetRecoilState(isEnableAutoSavingState) + const setIsControlNet = useSetRecoilState(isControlNetState) // Set Input Image useEffect(() => { @@ -75,10 +78,17 @@ function App() { setIsEnableAutoSavingState(isEnabled === 'true') } fetchData3() + + const fetchData4 = async () => { + const isEnabled = await getIsControlNet().then(res => res.text()) + setIsControlNet(isEnabled === 'true') + } + fetchData4() }, [ setEnableFileManager, setIsDisableModelSwitch, setIsEnableAutoSavingState, + setIsControlNet, ]) // Dark Mode Hotkey diff --git a/lama_cleaner/app/src/adapters/inpainting.ts b/lama_cleaner/app/src/adapters/inpainting.ts index ed94c30..954dab8 100644 --- a/lama_cleaner/app/src/adapters/inpainting.ts +++ b/lama_cleaner/app/src/adapters/inpainting.ts @@ -87,6 +87,12 @@ export default async function inpaint( fd.append('p2pImageGuidanceScale', settings.p2pImageGuidanceScale.toString()) fd.append('p2pGuidanceScale', settings.p2pGuidanceScale.toString()) + // ControlNet + fd.append( + 'controlnet_conditioning_scale', + settings.controlnetConditioningScale.toString() + ) + if (sizeLimit === undefined) { fd.append('sizeLimit', '1080') } else { @@ -116,6 +122,12 @@ export function getIsDisableModelSwitch() { }) } +export function getIsControlNet() { + return fetch(`${API_ENDPOINT}/is_controlnet`, { + method: 'GET', + }) +} + export function getEnableFileManager() { return fetch(`${API_ENDPOINT}/is_enable_file_manager`, { method: 'GET', diff --git a/lama_cleaner/app/src/components/SidePanel/SidePanel.tsx b/lama_cleaner/app/src/components/SidePanel/SidePanel.tsx index 59c58ef..97d8049 100644 --- a/lama_cleaner/app/src/components/SidePanel/SidePanel.tsx +++ b/lama_cleaner/app/src/components/SidePanel/SidePanel.tsx @@ -3,6 +3,7 @@ import { useRecoilState, useRecoilValue } from 'recoil' import * as PopoverPrimitive from '@radix-ui/react-popover' import { useToggle } from 'react-use' import { + isControlNetState, isInpaintingState, negativePropmtState, propmtState, @@ -26,6 +27,7 @@ const SidePanel = () => { useRecoilState(negativePropmtState) const isInpainting = useRecoilValue(isInpaintingState) 
const prompt = useRecoilValue(propmtState) + const isControlNet = useRecoilValue(isControlNetState) const handleOnInput = (evt: FormEvent) => { evt.preventDefault() @@ -115,6 +117,22 @@ const SidePanel = () => { }} /> + {isControlNet && ( + { + const val = value.length === 0 ? 0 : parseFloat(value) + setSettingState(old => { + return { ...old, controlnetConditioningScale: val } + }) + }} + /> + )} + ({ @@ -70,6 +71,7 @@ export const appState = atom({ enableFileManager: false, gifImage: undefined, brushSize: 40, + isControlNet: false, }, }) @@ -239,6 +241,18 @@ export const isDisableModelSwitchState = selector({ }, }) +export const isControlNetState = selector({ + key: 'isControlNetState', + get: ({ get }) => { + const app = get(appState) + return app.isControlNet + }, + set: ({ get, set }, newValue: any) => { + const app = get(appState) + set(appState, { ...app, isControlNet: newValue }) + }, +}) + export const isEnableAutoSavingState = selector({ key: 'isEnableAutoSavingState', get: ({ get }) => { @@ -379,6 +393,9 @@ export interface Settings { p2pSteps: number p2pImageGuidanceScale: number p2pGuidanceScale: number + + // ControlNet + controlnetConditioningScale: number } const defaultHDSettings: ModelsHDSettings = { @@ -482,6 +499,7 @@ export enum SDSampler { kEuler = 'k_euler', kEulerA = 'k_euler_a', dpmPlusPlus = 'dpm++', + uni_pc = 'uni_pc', } export enum SDMode { @@ -510,7 +528,7 @@ export const settingStateDefault: Settings = { sdStrength: 0.75, sdSteps: 50, sdGuidanceScale: 7.5, - sdSampler: SDSampler.pndm, + sdSampler: SDSampler.uni_pc, sdSeed: 42, sdSeedFixed: false, sdNumSamples: 1, @@ -533,6 +551,9 @@ export const settingStateDefault: Settings = { p2pSteps: 50, p2pImageGuidanceScale: 1.5, p2pGuidanceScale: 7.5, + + // ControlNet + controlnetConditioningScale: 0.4, } const localStorageEffect = diff --git a/lama_cleaner/const.py b/lama_cleaner/const.py index af9375c..149a2c0 100644 --- a/lama_cleaner/const.py +++ b/lama_cleaner/const.py @@ -6,7 +6,8 @@ MPS_SUPPORT_MODELS = [ "anything4", "realisticVision1.4", "sd2", - "paint_by_example" + "paint_by_example", + "controlnet", ] DEFAULT_MODEL = "lama" @@ -24,10 +25,12 @@ AVAILABLE_MODELS = [ "sd2", "paint_by_example", "instruct_pix2pix", + "controlnet", ] +SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"] AVAILABLE_DEVICES = ["cuda", "cpu", "mps"] -DEFAULT_DEVICE = 'cuda' +DEFAULT_DEVICE = "cuda" NO_HALF_HELP = """ Using full precision model. @@ -46,6 +49,10 @@ SD_CPU_TEXTENCODER_HELP = """ Run Stable Diffusion text encoder model on CPU to save GPU memory. """ +SD_CONTROLNET_HELP = """ +Run Stable Diffusion 1.5 inpainting model with controlNet-canny model. +""" + LOCAL_FILES_ONLY_HELP = """ Use local files only, not connect to Hugging Face server. (sd/paint_by_example) """ @@ -55,8 +62,7 @@ Enable xFormers optimizations. Requires xformers package has been installed. 
See """ DEFAULT_MODEL_DIR = os.getenv( - "XDG_CACHE_HOME", - os.path.join(os.path.expanduser("~"), ".cache") + "XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache") ) MODEL_DIR_HELP = """ Model download directory (by setting XDG_CACHE_HOME environment variable), by default model downloaded to ~/.cache diff --git a/lama_cleaner/model/controlnet.py b/lama_cleaner/model/controlnet.py new file mode 100644 index 0000000..46030cc --- /dev/null +++ b/lama_cleaner/model/controlnet.py @@ -0,0 +1,157 @@ +import PIL.Image +import cv2 +import numpy as np +import torch +from diffusers import ( + ControlNetModel, +) +from loguru import logger + +from lama_cleaner.model.base import DiffusionInpaintModel +from lama_cleaner.model.utils import torch_gc, get_scheduler +from lama_cleaner.schema import Config + + +class CPUTextEncoderWrapper: + def __init__(self, text_encoder, torch_dtype): + self.config = text_encoder.config + self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True) + self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True) + self.torch_dtype = torch_dtype + del text_encoder + torch_gc() + + def __call__(self, x, **kwargs): + input_device = x.device + return [ + self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0] + .to(input_device) + .to(self.torch_dtype) + ] + + @property + def dtype(self): + return self.torch_dtype + + +NAMES_MAP = { + "sd1.5": "runwayml/stable-diffusion-inpainting", + "anything4": "Sanster/anything-4.0-inpainting", + "realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting", +} + + +class ControlNet(DiffusionInpaintModel): + name = "controlnet" + pad_mod = 8 + min_size = 512 + + def init_model(self, device: torch.device, **kwargs): + from .pipeline import StableDiffusionControlNetInpaintPipeline + + model_id = NAMES_MAP[kwargs["name"]] + fp16 = not kwargs.get("no_half", False) + + model_kwargs = { + "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"]) + } + if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False): + logger.info("Disable Stable Diffusion Model NSFW checker") + model_kwargs.update( + dict( + safety_checker=None, + feature_extractor=None, + requires_safety_checker=False, + ) + ) + + use_gpu = device == torch.device("cuda") and torch.cuda.is_available() + torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32 + + controlnet = ControlNetModel.from_pretrained( + f"lllyasviel/sd-controlnet-canny", torch_dtype=torch_dtype + ) + self.model = StableDiffusionControlNetInpaintPipeline.from_pretrained( + model_id, + controlnet=controlnet, + revision="fp16" if use_gpu and fp16 else "main", + torch_dtype=torch_dtype, + **model_kwargs, + ) + + # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing + self.model.enable_attention_slicing() + # https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention + if kwargs.get("enable_xformers", False): + self.model.enable_xformers_memory_efficient_attention() + + if kwargs.get("cpu_offload", False) and use_gpu: + logger.info("Enable sequential cpu offload") + self.model.enable_sequential_cpu_offload(gpu_id=0) + else: + self.model = self.model.to(device) + if kwargs["sd_cpu_textencoder"]: + logger.info("Run Stable Diffusion TextEncoder on CPU") + self.model.text_encoder = CPUTextEncoderWrapper( + self.model.text_encoder, torch_dtype + ) + + self.callback = kwargs.pop("callback", None) + + def forward(self, 
image, mask, config: Config): + """Input image and output image have same size + image: [H, W, C] RGB + mask: [H, W, 1] 255 means area to repaint + return: BGR IMAGE + """ + + scheduler_config = self.model.scheduler.config + scheduler = get_scheduler(config.sd_sampler, scheduler_config) + self.model.scheduler = scheduler + + if config.sd_mask_blur != 0: + k = 2 * config.sd_mask_blur + 1 + mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis] + + img_h, img_w = image.shape[:2] + + canny_image = cv2.Canny(image, 100, 200) + canny_image = canny_image[:, :, None] + canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2) + canny_image = PIL.Image.fromarray(canny_image) + mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L") + image = PIL.Image.fromarray(image) + + output = self.model( + image=image, + control_image=canny_image, + prompt=config.prompt, + negative_prompt=config.negative_prompt, + mask_image=mask_image, + num_inference_steps=config.sd_steps, + guidance_scale=config.sd_guidance_scale, + output_type="np.array", + callback=self.callback, + height=img_h, + width=img_w, + generator=torch.manual_seed(config.sd_seed), + controlnet_conditioning_scale=config.controlnet_conditioning_scale, + ).images[0] + + output = (output * 255).round().astype("uint8") + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + def forward_post_process(self, result, image, mask, config): + if config.sd_match_histograms: + result = self._match_histograms(result, image[:, :, ::-1], mask) + + if config.sd_mask_blur != 0: + k = 2 * config.sd_mask_blur + 1 + mask = cv2.GaussianBlur(mask, (k, k), 0) + return result, image, mask + + @staticmethod + def is_downloaded() -> bool: + # model will be downloaded when app start, and can't switch in frontend settings + return True diff --git a/lama_cleaner/model/pipeline/__init__.py b/lama_cleaner/model/pipeline/__init__.py new file mode 100644 index 0000000..9056bc6 --- /dev/null +++ b/lama_cleaner/model/pipeline/__init__.py @@ -0,0 +1,3 @@ +from .pipeline_stable_diffusion_controlnet_inpaint import ( + StableDiffusionControlNetInpaintPipeline, +) diff --git a/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py b/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py new file mode 100644 index 0000000..b2a181c --- /dev/null +++ b/lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py @@ -0,0 +1,585 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Copy from https://github.com/mikonvergence/ControlNetInpaint/blob/main/src/pipeline_stable_diffusion_controlnet_inpaint.py + +import torch +import PIL.Image +import numpy as np + +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import * + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # !pip install opencv-python transformers accelerate + >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler + >>> from diffusers.utils import load_image + >>> import numpy as np + >>> import torch + + >>> import cv2 + >>> from PIL import Image + >>> # download an image + >>> image = load_image( + ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png" + ... ) + >>> image = np.array(image) + >>> mask_image = load_image( + ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png" + ... ) + >>> mask_image = np.array(mask_image) + >>> # get canny image + >>> canny_image = cv2.Canny(image, 100, 200) + >>> canny_image = canny_image[:, :, None] + >>> canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2) + >>> canny_image = Image.fromarray(canny_image) + + >>> # load control net and stable diffusion v1-5 + >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16) + >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( + ... "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16 + ... ) + + >>> # speed up diffusion process with faster scheduler and memory optimization + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) + >>> # remove following line if xformers is not installed + >>> pipe.enable_xformers_memory_efficient_attention() + + >>> pipe.enable_model_cpu_offload() + + >>> # generate image + >>> generator = torch.manual_seed(0) + >>> image = pipe( + ... "futuristic-looking doggo", + ... num_inference_steps=20, + ... generator=generator, + ... image=image, + ... control_image=canny_image, + ... mask_image=mask_image + ... ).images[0] + ``` +""" + + +def prepare_mask_and_masked_image(image, mask): + """ + Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be + converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the + ``image`` and ``1`` for the ``mask``. + The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be + binarized (``mask > 0.5``) and cast to ``torch.float32`` too. + Args: + image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint. + It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width`` + ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``. + mask (_type_): The mask to apply to the image, i.e. regions to inpaint. + It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width`` + ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``. + Raises: + ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask + should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions. 
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not + (ot the other way around). + Returns: + tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4 + dimensions: ``batch x channels x height x width``. + """ + if isinstance(image, torch.Tensor): + if not isinstance(mask, torch.Tensor): + raise TypeError( + f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not" + ) + + # Batch single image + if image.ndim == 3: + assert ( + image.shape[0] == 3 + ), "Image outside a batch should be of shape (3, H, W)" + image = image.unsqueeze(0) + + # Batch and add channel dim for single mask + if mask.ndim == 2: + mask = mask.unsqueeze(0).unsqueeze(0) + + # Batch single mask or add channel dim + if mask.ndim == 3: + # Single batched mask, no channel dim or single mask not batched but channel dim + if mask.shape[0] == 1: + mask = mask.unsqueeze(0) + + # Batched masks no channel dim + else: + mask = mask.unsqueeze(1) + + assert ( + image.ndim == 4 and mask.ndim == 4 + ), "Image and Mask must have 4 dimensions" + assert ( + image.shape[-2:] == mask.shape[-2:] + ), "Image and Mask must have the same spatial dimensions" + assert ( + image.shape[0] == mask.shape[0] + ), "Image and Mask must have the same batch size" + + # Check image is in [-1, 1] + if image.min() < -1 or image.max() > 1: + raise ValueError("Image should be in [-1, 1] range") + + # Check mask is in [0, 1] + if mask.min() < 0 or mask.max() > 1: + raise ValueError("Mask should be in [0, 1] range") + + # Binarize mask + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + + # Image as float32 + image = image.to(dtype=torch.float32) + elif isinstance(mask, torch.Tensor): + raise TypeError( + f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not" + ) + else: + # preprocess image + if isinstance(image, (PIL.Image.Image, np.ndarray)): + image = [image] + + if isinstance(image, list) and isinstance(image[0], PIL.Image.Image): + image = [np.array(i.convert("RGB"))[None, :] for i in image] + image = np.concatenate(image, axis=0) + elif isinstance(image, list) and isinstance(image[0], np.ndarray): + image = np.concatenate([i[None, :] for i in image], axis=0) + + image = image.transpose(0, 3, 1, 2) + image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 + + # preprocess mask + if isinstance(mask, (PIL.Image.Image, np.ndarray)): + mask = [mask] + + if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image): + mask = np.concatenate( + [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0 + ) + mask = mask.astype(np.float32) / 255.0 + elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): + mask = np.concatenate([m[None, None, :] for m in mask], axis=0) + + mask[mask < 0.5] = 0 + mask[mask >= 0.5] = 1 + mask = torch.from_numpy(mask) + + masked_image = image * (mask < 0.5) + + return mask, masked_image + + +class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline): + r""" + Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance. + + This model inherits from [`StableDiffusionControlNetPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. 
Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`]): + Provides additional conditioning to the unet during the denoising process + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + + def prepare_mask_latents( + self, + mask, + masked_image, + batch_size, + height, + width, + dtype, + device, + generator, + do_classifier_free_guidance, + ): + # resize the mask to latents shape as we concatenate the mask to the latents + # we do that before converting to dtype to avoid breaking in case we're using cpu_offload + # and half precision + mask = torch.nn.functional.interpolate( + mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor) + ) + mask = mask.to(device=device, dtype=dtype) + + masked_image = masked_image.to(device=device, dtype=dtype) + + # encode the mask image into latents space so we can concatenate it to the latents + if isinstance(generator, list): + masked_image_latents = [ + self.vae.encode(masked_image[i : i + 1]).latent_dist.sample( + generator=generator[i] + ) + for i in range(batch_size) + ] + masked_image_latents = torch.cat(masked_image_latents, dim=0) + else: + masked_image_latents = self.vae.encode(masked_image).latent_dist.sample( + generator=generator + ) + masked_image_latents = self.vae.config.scaling_factor * masked_image_latents + + # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method + if mask.shape[0] < batch_size: + if not batch_size % mask.shape[0] == 0: + raise ValueError( + "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" + f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" + " of masks that you pass is divisible by the total requested batch size." + ) + mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) + if masked_image_latents.shape[0] < batch_size: + if not batch_size % masked_image_latents.shape[0] == 0: + raise ValueError( + "The passed images and the required batch size don't match. Images are supposed to be duplicated" + f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." + " Make sure the number of images that you pass is divisible by the total requested batch size." 
+ ) + masked_image_latents = masked_image_latents.repeat( + batch_size // masked_image_latents.shape[0], 1, 1, 1 + ) + + mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask + masked_image_latents = ( + torch.cat([masked_image_latents] * 2) + if do_classifier_free_guidance + else masked_image_latents + ) + + # aligning device to prevent device errors when concating it with the latent model input + masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) + return mask, masked_image_latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + control_image: Union[ + torch.FloatTensor, + PIL.Image.Image, + List[torch.FloatTensor], + List[PIL.Image.Image], + ] = None, + mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: float = 1.0, + ): + r""" + Function invoked when calling the pipeline for generation. + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will + be masked out with `mask_image` and repainted according to `prompt`. + control_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. PIL.Image.Image` can + also be accepted as an image. The control image is automatically resized to fit the output image. + mask_image (`PIL.Image.Image`): + `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be + repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted + to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L) + instead of 3, so the expected shape would be `(B, H, W, 1)`. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. 
+ guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. 
+ Examples: + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height, width = self._default_height_width(height, width, control_image) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + control_image, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare image + control_image = self.prepare_image( + control_image, + width, + height, + batch_size * num_images_per_prompt, + num_images_per_prompt, + device, + self.controlnet.dtype, + ) + + if do_classifier_free_guidance: + control_image = torch.cat([control_image] * 2) + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.controlnet.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # EXTRA: prepare mask latents + mask, masked_image = prepare_mask_and_masked_image(image, mask_image) + mask, masked_image_latents = self.prepare_mask_latents( + mask, + masked_image, + batch_size * num_images_per_prompt, + height, + width, + prompt_embeds.dtype, + device, + generator, + do_classifier_free_guidance, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + + down_block_res_samples, mid_block_res_sample = self.controlnet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + controlnet_cond=control_image, + return_dict=False, + ) + + down_block_res_samples = [ + down_block_res_sample * controlnet_conditioning_scale + for down_block_res_sample in down_block_res_samples + ] + mid_block_res_sample *= controlnet_conditioning_scale + + # predict the noise residual + latent_model_input = torch.cat( + [latent_model_input, mask, masked_image_latents], dim=1 + ) + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs + ).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + if output_type == "latent": + image = latents + has_nsfw_concept = None + elif output_type == "pil": + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. Run safety checker + image, has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) + + # 10. Convert to PIL + image = self.numpy_to_pil(image) + else: + # 8. Post-processing + image = self.decode_latents(latents) + + # 9. 
Run safety checker + image, has_nsfw_concept = self.run_safety_checker( + image, device, prompt_embeds.dtype + ) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) diff --git a/lama_cleaner/model/sd.py b/lama_cleaner/model/sd.py index 019de9b..39ced29 100644 --- a/lama_cleaner/model/sd.py +++ b/lama_cleaner/model/sd.py @@ -1,22 +1,12 @@ -import random - import PIL.Image import cv2 import numpy as np import torch -from diffusers import ( - PNDMScheduler, - DDIMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, -) from loguru import logger from lama_cleaner.model.base import DiffusionInpaintModel -from lama_cleaner.model.utils import torch_gc, set_seed -from lama_cleaner.schema import Config, SDSampler +from lama_cleaner.model.utils import torch_gc, get_scheduler +from lama_cleaner.schema import Config class CPUTextEncoderWrapper: @@ -101,22 +91,7 @@ class SD(DiffusionInpaintModel): """ scheduler_config = self.model.scheduler.config - - if config.sd_sampler == SDSampler.ddim: - scheduler = DDIMScheduler.from_config(scheduler_config) - elif config.sd_sampler == SDSampler.pndm: - scheduler = PNDMScheduler.from_config(scheduler_config) - elif config.sd_sampler == SDSampler.k_lms: - scheduler = LMSDiscreteScheduler.from_config(scheduler_config) - elif config.sd_sampler == SDSampler.k_euler: - scheduler = EulerDiscreteScheduler.from_config(scheduler_config) - elif config.sd_sampler == SDSampler.k_euler_a: - scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config) - elif config.sd_sampler == SDSampler.dpm_plus_plus: - scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config) - else: - raise ValueError(config.sd_sampler) - + scheduler = get_scheduler(config.sd_sampler, scheduler_config) self.model.scheduler = scheduler if config.sd_mask_blur != 0: diff --git a/lama_cleaner/model/utils.py b/lama_cleaner/model/utils.py index 52961b1..352af04 100644 --- a/lama_cleaner/model/utils.py +++ b/lama_cleaner/model/utils.py @@ -7,17 +7,35 @@ import numpy as np import collections from itertools import repeat +from diffusers import ( + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + UniPCMultistepScheduler, +) + +from lama_cleaner.schema import SDSampler from torch import conv2d, conv_transpose2d -def make_beta_schedule(device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): +def make_beta_schedule( + device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3 +): if schedule == "linear": betas = ( - torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + torch.linspace( + linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64 + ) + ** 2 ) elif schedule == "cosine": - timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s).to(device) + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ).to(device) alphas = timesteps / (1 + cosine_s) * np.pi / 2 alphas = torch.cos(alphas).pow(2).to(device) alphas = alphas / alphas[0] @@ -25,9 +43,14 @@ def make_beta_schedule(device, schedule, 
n_timestep, linear_start=1e-4, linear_e betas = np.clip(betas, a_min=0, a_max=0.999) elif schedule == "sqrt_linear": - betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + betas = torch.linspace( + linear_start, linear_end, n_timestep, dtype=torch.float64 + ) elif schedule == "sqrt": - betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + betas = ( + torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + ** 0.5 + ) else: raise ValueError(f"schedule '{schedule}' unknown.") return betas.numpy() @@ -39,33 +62,47 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) # according the the formula provided in https://arxiv.org/abs/2010.02502 - sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + sigmas = eta * np.sqrt( + (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev) + ) if verbose: - print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') - print(f'For the chosen value of eta, which is {eta}, ' - f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + print( + f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}" + ) + print( + f"For the chosen value of eta, which is {eta}, " + f"this results in the following sigma_t schedule for ddim sampler {sigmas}" + ) return sigmas, alphas, alphas_prev -def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): - if ddim_discr_method == 'uniform': +def make_ddim_timesteps( + ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True +): + if ddim_discr_method == "uniform": c = num_ddpm_timesteps // num_ddim_timesteps ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) - elif ddim_discr_method == 'quad': - ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + elif ddim_discr_method == "quad": + ddim_timesteps = ( + (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2 + ).astype(int) else: - raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + raise NotImplementedError( + f'There is no ddim discretization method called "{ddim_discr_method}"' + ) # assert ddim_timesteps.shape[0] == num_ddim_timesteps # add one to get the final alpha values right (the ones from first scale to data during sampling) steps_out = ddim_timesteps + 1 if verbose: - print(f'Selected timesteps for ddim sampler: {steps_out}') + print(f"Selected timesteps for ddim sampler: {steps_out}") return steps_out def noise_like(shape, device, repeat=False): - repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1,) * (len(shape) - 1)) + ) noise = lambda: torch.randn(shape, device=device) return repeat_noise() if repeat else noise() @@ -81,7 +118,9 @@ def timestep_embedding(device, timesteps, dim, max_period=10000, repeat_only=Fal """ half = dim // 2 freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half ).to(device=device) args = timesteps[:, None].float() * freqs[None] @@ -115,9 +154,8 @@ class 
EasyDict(dict): del self[name] -def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): - """Slow reference implementation of `bias_act()` using standard TensorFlow ops. - """ +def _bias_act_ref(x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None): + """Slow reference implementation of `bias_act()` using standard TensorFlow ops.""" assert isinstance(x, torch.Tensor) assert clamp is None or clamp >= 0 spec = activation_funcs[act] @@ -147,7 +185,9 @@ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=N return x -def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='ref'): +def bias_act( + x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None, impl="ref" +): r"""Fused bias and activation function. Adds bias `b` to activation tensor `x`, evaluates activation function `act`, @@ -178,8 +218,10 @@ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, Tensor of the same shape and datatype as `x`. """ assert isinstance(x, torch.Tensor) - assert impl in ['ref', 'cuda'] - return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) + assert impl in ["ref", "cuda"] + return _bias_act_ref( + x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp + ) def _get_filter_size(f): @@ -223,7 +265,14 @@ def _parse_padding(padding): return padx0, padx1, pady0, pady1 -def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None): +def setup_filter( + f, + device=torch.device("cpu"), + normalize=True, + flip_filter=False, + gain=1, + separable=None, +): r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`. Args: @@ -255,7 +304,7 @@ def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=Fals # Separable? 
if separable is None: - separable = (f.ndim == 1 and f.numel() >= 8) + separable = f.ndim == 1 and f.numel() >= 8 if f.ndim == 1 and not separable: f = f.ger(f) assert f.ndim == (1 if separable else 2) @@ -282,27 +331,82 @@ def _ntuple(n): to_2tuple = _ntuple(2) activation_funcs = { - 'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), - 'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, - ref='y', has_2nd_grad=False), - 'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, - def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), - 'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', - has_2nd_grad=True), - 'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', - has_2nd_grad=True), - 'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', - has_2nd_grad=True), - 'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', - has_2nd_grad=True), - 'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, - ref='y', has_2nd_grad=True), - 'swish': EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', - has_2nd_grad=True), + "linear": EasyDict( + func=lambda x, **_: x, + def_alpha=0, + def_gain=1, + cuda_idx=1, + ref="", + has_2nd_grad=False, + ), + "relu": EasyDict( + func=lambda x, **_: torch.nn.functional.relu(x), + def_alpha=0, + def_gain=np.sqrt(2), + cuda_idx=2, + ref="y", + has_2nd_grad=False, + ), + "lrelu": EasyDict( + func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), + def_alpha=0.2, + def_gain=np.sqrt(2), + cuda_idx=3, + ref="y", + has_2nd_grad=False, + ), + "tanh": EasyDict( + func=lambda x, **_: torch.tanh(x), + def_alpha=0, + def_gain=1, + cuda_idx=4, + ref="y", + has_2nd_grad=True, + ), + "sigmoid": EasyDict( + func=lambda x, **_: torch.sigmoid(x), + def_alpha=0, + def_gain=1, + cuda_idx=5, + ref="y", + has_2nd_grad=True, + ), + "elu": EasyDict( + func=lambda x, **_: torch.nn.functional.elu(x), + def_alpha=0, + def_gain=1, + cuda_idx=6, + ref="y", + has_2nd_grad=True, + ), + "selu": EasyDict( + func=lambda x, **_: torch.nn.functional.selu(x), + def_alpha=0, + def_gain=1, + cuda_idx=7, + ref="y", + has_2nd_grad=True, + ), + "softplus": EasyDict( + func=lambda x, **_: torch.nn.functional.softplus(x), + def_alpha=0, + def_gain=1, + cuda_idx=8, + ref="y", + has_2nd_grad=True, + ), + "swish": EasyDict( + func=lambda x, **_: torch.sigmoid(x) * x, + def_alpha=0, + def_gain=np.sqrt(2), + cuda_idx=9, + ref="x", + has_2nd_grad=True, + ), } -def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'): +def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"): r"""Pad, upsample, filter, and downsample a batch of 2D images. 
Performs the following sequence of operations for each channel: @@ -344,12 +448,13 @@ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cu """ # assert isinstance(x, torch.Tensor) # assert impl in ['ref', 'cuda'] - return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain) + return _upfirdn2d_ref( + x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain + ) def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): - """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops. - """ + """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.""" # Validate arguments. assert isinstance(x, torch.Tensor) and x.ndim == 4 if f is None: @@ -372,8 +477,15 @@ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx]) # Pad or crop. - x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]) - x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), max(-padx0, 0): x.shape[3] - max(-padx1, 0)] + x = torch.nn.functional.pad( + x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)] + ) + x = x[ + :, + :, + max(-pady0, 0) : x.shape[2] - max(-pady1, 0), + max(-padx0, 0) : x.shape[3] - max(-padx1, 0), + ] # Setup filter. f = f * (gain ** (f.ndim / 2)) @@ -394,7 +506,7 @@ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1): return x -def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'): +def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl="cuda"): r"""Downsample a batch of 2D images using the given 2D FIR filter. By default, the result is padded so that its shape is a fraction of the input. @@ -431,10 +543,12 @@ def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda' pady0 + (fh - downy + 1) // 2, pady1 + (fh - downy) // 2, ] - return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl) + return upfirdn2d( + x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl + ) -def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): +def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl="cuda"): r"""Upsample a batch of 2D images using the given 2D FIR filter. By default, the result is padded so that its shape is a multiple of the input. @@ -471,7 +585,15 @@ def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'): pady0 + (fh + upy - 1) // 2, pady1 + (fh - upy) // 2, ] - return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain * upx * upy, impl=impl) + return upfirdn2d( + x, + f, + up=up, + padding=p, + flip_filter=flip_filter, + gain=gain * upx * upy, + impl=impl, + ) class MinibatchStdLayer(torch.nn.Module): @@ -482,13 +604,17 @@ class MinibatchStdLayer(torch.nn.Module): def forward(self, x): N, C, H, W = x.shape - G = torch.min(torch.as_tensor(self.group_size), - torch.as_tensor(N)) if self.group_size is not None else N + G = ( + torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N)) + if self.group_size is not None + else N + ) F = self.num_channels c = C // F - y = x.reshape(G, -1, F, c, H, - W) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c. 
+ y = x.reshape( + G, -1, F, c, H, W + ) # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c. y = y - y.mean(dim=0) # [GnFcHW] Subtract mean over group. y = y.square().mean(dim=0) # [nFcHW] Calc variance over group. y = (y + 1e-8).sqrt() # [nFcHW] Calc stddev over group. @@ -500,17 +626,24 @@ class MinibatchStdLayer(torch.nn.Module): class FullyConnectedLayer(torch.nn.Module): - def __init__(self, - in_features, # Number of input features. - out_features, # Number of output features. - bias=True, # Apply additive bias before the activation function? - activation='linear', # Activation function: 'relu', 'lrelu', etc. - lr_multiplier=1, # Learning rate multiplier. - bias_init=0, # Initial value for the additive bias. - ): + def __init__( + self, + in_features, # Number of input features. + out_features, # Number of output features. + bias=True, # Apply additive bias before the activation function? + activation="linear", # Activation function: 'relu', 'lrelu', etc. + lr_multiplier=1, # Learning rate multiplier. + bias_init=0, # Initial value for the additive bias. + ): super().__init__() - self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier) - self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None + self.weight = torch.nn.Parameter( + torch.randn([out_features, in_features]) / lr_multiplier + ) + self.bias = ( + torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) + if bias + else None + ) self.activation = activation self.weight_gain = lr_multiplier / np.sqrt(in_features) @@ -522,7 +655,7 @@ class FullyConnectedLayer(torch.nn.Module): if b is not None and self.bias_gain != 1: b = b * self.bias_gain - if self.activation == 'linear' and b is not None: + if self.activation == "linear" and b is not None: # out = torch.addmm(b.unsqueeze(0), x, w.t()) x = x.matmul(w.t()) out = x + b.reshape([-1 if i == x.ndim - 1 else 1 for i in range(x.ndim)]) @@ -532,22 +665,33 @@ class FullyConnectedLayer(torch.nn.Module): return out -def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): - """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. - """ +def _conv2d_wrapper( + x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True +): + """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.""" out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) # Flip weight if requested. - if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). + if ( + not flip_weight + ): # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). w = w.flip([2, 3]) # Workaround performance pitfall in cuDNN 8.0.5, triggered when using # 1x1 kernel + memory_format=channels_last + less than 64 channels. 
- if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: + if ( + kw == 1 + and kh == 1 + and stride == 1 + and padding in [0, [0, 0], (0, 0)] + and not transpose + ): if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: if out_channels <= 4 and groups == 1: in_shape = x.shape - x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) + x = w.squeeze(3).squeeze(2) @ x.reshape( + [in_shape[0], in_channels_per_group, -1] + ) x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) else: x = x.to(memory_format=torch.contiguous_format) @@ -560,7 +704,9 @@ def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_w return op(x, w, stride=stride, padding=padding, groups=groups) -def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): +def conv2d_resample( + x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False +): r"""2D convolution with optional up/downsampling. Padding is performed only once at the beginning, not between the operations. @@ -587,7 +733,9 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight # Validate arguments. assert isinstance(x, torch.Tensor) and (x.ndim == 4) assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) - assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) + assert f is None or ( + isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32 + ) assert isinstance(up, int) and (up >= 1) assert isinstance(down, int) and (down >= 1) # assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}" @@ -610,20 +758,31 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. if kw == 1 and kh == 1 and (down > 1 and up == 1): - x = upfirdn2d(x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter) + x = upfirdn2d( + x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter + ) x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) return x # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. if kw == 1 and kh == 1 and (up > 1 and down == 1): x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) - x = upfirdn2d(x=x, f=f, up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter) + x = upfirdn2d( + x=x, + f=f, + up=up, + padding=[px0, px1, py0, py1], + gain=up**2, + flip_filter=flip_filter, + ) return x # Fast path: downsampling only => use strided convolution. if down > 1 and up == 1: x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter) - x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) + x = _conv2d_wrapper( + x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight + ) return x # Fast path: upsampling with optional downsampling => use transpose strided convolution. 
@@ -633,17 +792,31 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight
         else:
             w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
             w = w.transpose(1, 2)
-            w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
+            w = w.reshape(
+                groups * in_channels_per_group, out_channels // groups, kh, kw
+            )
         px0 -= kw - 1
         px1 -= kw - up
         py0 -= kh - 1
         py1 -= kh - up
         pxt = max(min(-px0, -px1), 0)
         pyt = max(min(-py0, -py1), 0)
-        x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt, pxt], groups=groups, transpose=True,
-                            flip_weight=(not flip_weight))
-        x = upfirdn2d(x=x, f=f, padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt], gain=up ** 2,
-                      flip_filter=flip_filter)
+        x = _conv2d_wrapper(
+            x=x,
+            w=w,
+            stride=up,
+            padding=[pyt, pxt],
+            groups=groups,
+            transpose=True,
+            flip_weight=(not flip_weight),
+        )
+        x = upfirdn2d(
+            x=x,
+            f=f,
+            padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt],
+            gain=up**2,
+            flip_filter=flip_filter,
+        )
         if down > 1:
             x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
         return x
@@ -651,11 +824,19 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight
     # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
     if up == 1 and down == 1:
         if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
-            return _conv2d_wrapper(x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight)
+            return _conv2d_wrapper(
+                x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight
+            )

     # Fallback: Generic reference implementation.
-    x = upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0, px1, py0, py1], gain=up ** 2,
-                  flip_filter=flip_filter)
+    x = upfirdn2d(
+        x=x,
+        f=(f if up > 1 else None),
+        up=up,
+        padding=[px0, px1, py0, py1],
+        gain=up**2,
+        flip_filter=flip_filter,
+    )
     x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
     if down > 1:
         x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
@@ -663,50 +844,68 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight


 class Conv2dLayer(torch.nn.Module):
-    def __init__(self,
-                 in_channels,  # Number of input channels.
-                 out_channels,  # Number of output channels.
-                 kernel_size,  # Width and height of the convolution kernel.
-                 bias=True,  # Apply additive bias before the activation function?
-                 activation='linear',  # Activation function: 'relu', 'lrelu', etc.
-                 up=1,  # Integer upsampling factor.
-                 down=1,  # Integer downsampling factor.
-                 resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
-                 conv_clamp=None,  # Clamp the output to +-X, None = disable clamping.
-                 channels_last=False,  # Expect the input to have memory_format=channels_last?
-                 trainable=True,  # Update the weights of this layer during training?
-                 ):
+    def __init__(
+        self,
+        in_channels,  # Number of input channels.
+        out_channels,  # Number of output channels.
+        kernel_size,  # Width and height of the convolution kernel.
+        bias=True,  # Apply additive bias before the activation function?
+        activation="linear",  # Activation function: 'relu', 'lrelu', etc.
+        up=1,  # Integer upsampling factor.
+        down=1,  # Integer downsampling factor.
+        resample_filter=[
+            1,
+            3,
+            3,
+            1,
+        ],  # Low-pass filter to apply when resampling activations.
+        conv_clamp=None,  # Clamp the output to +-X, None = disable clamping.
+        channels_last=False,  # Expect the input to have memory_format=channels_last?
+        trainable=True,  # Update the weights of this layer during training?
+    ):
         super().__init__()
         self.activation = activation
         self.up = up
         self.down = down
-        self.register_buffer('resample_filter', setup_filter(resample_filter))
+        self.register_buffer("resample_filter", setup_filter(resample_filter))
         self.conv_clamp = conv_clamp
         self.padding = kernel_size // 2
-        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
+        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
         self.act_gain = activation_funcs[activation].def_gain

-        memory_format = torch.channels_last if channels_last else torch.contiguous_format
-        weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
+        memory_format = (
+            torch.channels_last if channels_last else torch.contiguous_format
+        )
+        weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(
+            memory_format=memory_format
+        )
         bias = torch.zeros([out_channels]) if bias else None
         if trainable:
             self.weight = torch.nn.Parameter(weight)
             self.bias = torch.nn.Parameter(bias) if bias is not None else None
         else:
-            self.register_buffer('weight', weight)
+            self.register_buffer("weight", weight)
             if bias is not None:
-                self.register_buffer('bias', bias)
+                self.register_buffer("bias", bias)
             else:
                 self.bias = None

     def forward(self, x, gain=1):
         w = self.weight * self.weight_gain
-        x = conv2d_resample(x=x, w=w, f=self.resample_filter, up=self.up, down=self.down,
-                            padding=self.padding)
+        x = conv2d_resample(
+            x=x,
+            w=w,
+            f=self.resample_filter,
+            up=self.up,
+            down=self.down,
+            padding=self.padding,
+        )

         act_gain = self.act_gain * gain
         act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
-        out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp)
+        out = bias_act(
+            x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp
+        )
         return out


@@ -721,3 +920,22 @@ def set_seed(seed: int):
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
+
+
+def get_scheduler(sd_sampler, scheduler_config):
+    if sd_sampler == SDSampler.ddim:
+        return DDIMScheduler.from_config(scheduler_config)
+    elif sd_sampler == SDSampler.pndm:
+        return PNDMScheduler.from_config(scheduler_config)
+    elif sd_sampler == SDSampler.k_lms:
+        return LMSDiscreteScheduler.from_config(scheduler_config)
+    elif sd_sampler == SDSampler.k_euler:
+        return EulerDiscreteScheduler.from_config(scheduler_config)
+    elif sd_sampler == SDSampler.k_euler_a:
+        return EulerAncestralDiscreteScheduler.from_config(scheduler_config)
+    elif sd_sampler == SDSampler.dpm_plus_plus:
+        return DPMSolverMultistepScheduler.from_config(scheduler_config)
+    elif sd_sampler == SDSampler.uni_pc:
+        return UniPCMultistepScheduler.from_config(scheduler_config)
+    else:
+        raise ValueError(sd_sampler)
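Note on the scheduler helper above: get_scheduler() maps each SDSampler value to the matching diffusers scheduler via from_config(), so switching samplers only swaps the scheduler object while keeping the loaded pipeline's noise-schedule configuration. A minimal usage sketch follows; the import path lama_cleaner.model.utils is an assumption, since this patch does not show the module header.

    from diffusers import StableDiffusionInpaintPipeline

    from lama_cleaner.model.utils import get_scheduler  # assumed module path
    from lama_cleaner.schema import SDSampler

    # Rebuild the scheduler from the pipeline's own config so its schedule
    # parameters stay consistent with the loaded checkpoint.
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "runwayml/stable-diffusion-inpainting"
    )
    pipe.scheduler = get_scheduler(SDSampler.uni_pc, pipe.scheduler.config)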
diff --git a/lama_cleaner/model_manager.py b/lama_cleaner/model_manager.py
index bdf0d78..ac9abe3 100644
--- a/lama_cleaner/model_manager.py
+++ b/lama_cleaner/model_manager.py
@@ -1,7 +1,9 @@
 import torch
 import gc

+from lama_cleaner.const import SD15_MODELS
 from lama_cleaner.helper import switch_mps_device
+from lama_cleaner.model.controlnet import ControlNet
 from lama_cleaner.model.fcf import FcF
 from lama_cleaner.model.lama import LaMa
 from lama_cleaner.model.ldm import LDM
@@ -20,7 +22,7 @@ models = {
     "zits": ZITS,
     "mat": MAT,
     "fcf": FcF,
-    "sd1.5": SD15,
+    SD15.name: SD15,
     Anything4.name: Anything4,
     RealisticVision14.name: RealisticVision14,
     "cv2": OpenCV2,
@@ -39,6 +41,9 @@ class ModelManager:
         self.model = self.init_model(name, device, **kwargs)

     def init_model(self, name: str, device, **kwargs):
+        if name in SD15_MODELS and kwargs.get("sd_controlnet", False):
+            return ControlNet(device, **{**kwargs, "name": name})
+
         if name in models:
             model = models[name](device, **kwargs)
         else:
diff --git a/lama_cleaner/parse_args.py b/lama_cleaner/parse_args.py
index 394b3e1..6a3fb1a 100644
--- a/lama_cleaner/parse_args.py
+++ b/lama_cleaner/parse_args.py
@@ -5,25 +5,7 @@ from pathlib import Path

 from loguru import logger

-from lama_cleaner.const import (
-    AVAILABLE_MODELS,
-    NO_HALF_HELP,
-    CPU_OFFLOAD_HELP,
-    DISABLE_NSFW_HELP,
-    SD_CPU_TEXTENCODER_HELP,
-    LOCAL_FILES_ONLY_HELP,
-    AVAILABLE_DEVICES,
-    ENABLE_XFORMERS_HELP,
-    MODEL_DIR_HELP,
-    OUTPUT_DIR_HELP,
-    INPUT_HELP,
-    GUI_HELP,
-    DEFAULT_DEVICE,
-    NO_GUI_AUTO_CLOSE_HELP,
-    DEFAULT_MODEL_DIR,
-    DEFAULT_MODEL,
-    MPS_SUPPORT_MODELS,
-)
+from lama_cleaner.const import *
 from lama_cleaner.runtime import dump_environment_info


@@ -55,6 +37,9 @@ def parse_args():
     parser.add_argument(
         "--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP
     )
+    parser.add_argument(
+        "--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP
+    )
     parser.add_argument(
         "--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP
     )
@@ -133,6 +118,10 @@ def parse_args():
                "torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation"
            )

+    if args.sd_controlnet:
+        if args.model not in SD15_MODELS:
+            logger.warning(f"--sd-controlnet only supports {SD15_MODELS}")
+
     if args.model_dir and args.model_dir is not None:
         if os.path.isfile(args.model_dir):
             parser.error(f"invalid --model-dir: {args.model_dir} is a file")
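The new --sd-controlnet flag only takes effect for the SD 1.5 family checkpoints listed in SD15_MODELS; for any other --model value it is ignored with a warning. When it does apply, ModelManager.init_model() (changed above) routes construction to the ControlNet wrapper instead of the plain model registry. A rough sketch of that dispatch from Python; the ControlNet class itself is not part of this patch, so any extra keyword arguments it needs (no_half, hf_access_token, ...) are assumptions and omitted here.

    import torch

    from lama_cleaner.model_manager import ModelManager

    # With sd_controlnet=True and an SD 1.5 based model name, init_model()
    # returns the ControlNet wrapper; this will download model weights on
    # first use, just like the regular code path.
    model = ModelManager(
        name="sd1.5",
        device=torch.device("cuda"),
        sd_controlnet=True,
    )

From the command line the equivalent is the console script invocation, e.g. lama-cleaner --model sd1.5 --sd-controlnet --device cuda.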
diff --git a/lama_cleaner/schema.py b/lama_cleaner/schema.py
index 96edf60..2674316 100644
--- a/lama_cleaner/schema.py
+++ b/lama_cleaner/schema.py
@@ -24,9 +24,10 @@ class SDSampler(str, Enum):
     ddim = "ddim"
     pndm = "pndm"
     k_lms = "k_lms"
-    k_euler = 'k_euler'
-    k_euler_a = 'k_euler_a'
-    dpm_plus_plus = 'dpm++'
+    k_euler = "k_euler"
+    k_euler_a = "k_euler_a"
+    dpm_plus_plus = "dpm++"
+    uni_pc = "uni_pc"


 class Config(BaseModel):
@@ -71,14 +72,14 @@ class Config(BaseModel):
     # Higher guidance scale encourages to generate images that are closely linked
     # to the text prompt, usually at the expense of lower image quality.
     sd_guidance_scale: float = 7.5
-    sd_sampler: str = SDSampler.ddim
+    sd_sampler: str = SDSampler.uni_pc
     # -1 mean random seed
     sd_seed: int = 42
     sd_match_histograms: bool = False

     # Configs for opencv inpainting
     # opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
-    cv2_flag: str = 'INPAINT_NS'
+    cv2_flag: str = "INPAINT_NS"
     cv2_radius: int = 4

     # Paint by Example
@@ -93,3 +94,6 @@ class Config(BaseModel):
     p2p_steps: int = 50
     p2p_image_guidance_scale: float = 1.5
     p2p_guidance_scale: float = 7.5
+
+    # ControlNet
+    controlnet_conditioning_scale: float = 0.4
diff --git a/lama_cleaner/server.py b/lama_cleaner/server.py
index 7fe31c2..61a5d33 100644
--- a/lama_cleaner/server.py
+++ b/lama_cleaner/server.py
@@ -18,6 +18,7 @@ import numpy as np
 from loguru import logger
 from watchdog.events import FileSystemEventHandler

+from lama_cleaner.const import SD15_MODELS
 from lama_cleaner.interactive_seg import InteractiveSeg, Click
 from lama_cleaner.make_gif import make_compare_gif
 from lama_cleaner.model_manager import ModelManager
@@ -88,6 +89,7 @@ interactive_seg_model: InteractiveSeg = None
 device = None
 input_image_path: str = None
 is_disable_model_switch: bool = False
+is_controlnet: bool = False
 is_enable_file_manager: bool = False
 is_enable_auto_saving: bool = False
 is_desktop: bool = False
@@ -257,6 +259,7 @@ def process():
         p2p_steps=form["p2pSteps"],
         p2p_image_guidance_scale=form["p2pImageGuidanceScale"],
         p2p_guidance_scale=form["p2pGuidanceScale"],
+        controlnet_conditioning_scale=form["controlnet_conditioning_scale"],
     )

     if config.sd_seed == -1:
@@ -347,6 +350,12 @@ def current_model():
     return model.name, 200


+@app.route("/is_controlnet")
+def get_is_controlnet():
+    res = "true" if is_controlnet else "false"
+    return res, 200
+
+
 @app.route("/is_disable_model_switch")
 def get_is_disable_model_switch():
     res = "true" if is_disable_model_switch else "false"
@@ -427,6 +436,9 @@ def main(args):
     global thumb
     global output_dir
     global is_enable_auto_saving
+    global is_controlnet
+    if args.sd_controlnet and args.model in SD15_MODELS:
+        is_controlnet = True

     output_dir = args.output_dir
     if output_dir is not None:
@@ -464,6 +476,7 @@ def main(args):

     model = ModelManager(
         name=args.model,
+        sd_controlnet=args.sd_controlnet,
         device=device,
         no_half=args.no_half,
         hf_access_token=args.hf_access_token,
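Two server-side details worth noting: the pydantic Config model coerces the incoming controlnet_conditioning_scale form string to a float, and the new /is_controlnet route lets a client discover whether the ControlNet path is active. A small check against a locally running instance; host and port are the defaults from web_config below, adjust as needed.

    import requests  # third-party HTTP client, used here only for illustration

    resp = requests.get("http://127.0.0.1:8080/is_controlnet")
    # Prints "true" only when the server was started with --sd-controlnet and
    # an SD 1.5 family model; otherwise "false".
    print(resp.text)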
diff --git a/lama_cleaner/web_config.py b/lama_cleaner/web_config.py
index a01cbf3..1643f27 100644
--- a/lama_cleaner/web_config.py
+++ b/lama_cleaner/web_config.py
@@ -6,25 +6,7 @@ import gradio as gr
 from loguru import logger
 from pydantic import BaseModel

-from lama_cleaner.const import (
-    AVAILABLE_MODELS,
-    AVAILABLE_DEVICES,
-    CPU_OFFLOAD_HELP,
-    NO_HALF_HELP,
-    DISABLE_NSFW_HELP,
-    SD_CPU_TEXTENCODER_HELP,
-    LOCAL_FILES_ONLY_HELP,
-    ENABLE_XFORMERS_HELP,
-    MODEL_DIR_HELP,
-    OUTPUT_DIR_HELP,
-    INPUT_HELP,
-    GUI_HELP,
-    DEFAULT_MODEL,
-    DEFAULT_DEVICE,
-    NO_GUI_AUTO_CLOSE_HELP,
-    DEFAULT_MODEL_DIR,
-    MPS_SUPPORT_MODELS,
-)
+from lama_cleaner.const import *

 _config_file = None

@@ -33,6 +15,7 @@ class Config(BaseModel):
     host: str = "127.0.0.1"
     port: int = 8080
     model: str = DEFAULT_MODEL
+    sd_controlnet: bool = False
     device: str = DEFAULT_DEVICE
     gui: bool = False
     no_gui_auto_close: bool = False
@@ -59,6 +42,7 @@ def save_config(
     host,
     port,
     model,
+    sd_controlnet,
     device,
     gui,
     no_gui_auto_close,
@@ -127,6 +111,9 @@ def main(config_file: str):
             disable_nsfw = gr.Checkbox(
                 init_config.disable_nsfw, label=f"{DISABLE_NSFW_HELP}"
             )
+            sd_controlnet = gr.Checkbox(
+                init_config.sd_controlnet, label=f"{SD_CONTROLNET_HELP}"
+            )
             sd_cpu_textencoder = gr.Checkbox(
                 init_config.sd_cpu_textencoder, label=f"{SD_CPU_TEXTENCODER_HELP}"
             )
@@ -149,6 +136,7 @@ def main(config_file: str):
             host,
             port,
             model,
+            sd_controlnet,
             device,
             gui,
             no_gui_auto_close,
diff --git a/requirements.txt b/requirements.txt
index accfd58..27127f9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,7 +12,7 @@ pytest
 yacs
 markupsafe==2.0.1
 scikit-image==0.19.3
-diffusers[torch]==0.12.1
+diffusers[torch]==0.14.0
 transformers>=4.25.1
 watchdog==2.2.1
 gradio
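Finally, the diffusers pin moves from 0.12.1 to 0.14.0, the release that ships UniPCMultistepScheduler (used by the new default sampler above) and the ControlNet model and pipeline classes the new wrapper relies on. A quick environment sanity check after upgrading:

    # Both imports fail on diffusers < 0.14.0, so this verifies the pinned upgrade.
    from diffusers import ControlNetModel, UniPCMultistepScheduler

    print(ControlNetModel, UniPCMultistepScheduler)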