add controlnet inpainting

Qing 2023-03-19 22:40:23 +08:00
parent 61928c9861
commit 5f4c62ac18
17 changed files with 1197 additions and 186 deletions

View File

@ -1,11 +1,18 @@
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import warnings
warnings.simplefilter("ignore", UserWarning)
from lama_cleaner.parse_args import parse_args
def entry_point():
args = parse_args()
# To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir work for diffusers
# https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18
from lama_cleaner.server import main
main(args)

View File

@ -7,6 +7,7 @@ import Workspace from './components/Workspace'
import {
enableFileManagerState,
fileState,
isControlNetState,
isDisableModelSwitchState,
isEnableAutoSavingState,
toastState,
@ -17,6 +18,7 @@ import useHotKey from './hooks/useHotkey'
import {
getEnableAutoSaving,
getEnableFileManager,
getIsControlNet,
getIsDisableModelSwitch,
isDesktop,
} from './adapters/inpainting'
@ -37,6 +39,7 @@ function App() {
const setIsDisableModelSwitch = useSetRecoilState(isDisableModelSwitchState)
const setEnableFileManager = useSetRecoilState(enableFileManagerState)
const setIsEnableAutoSavingState = useSetRecoilState(isEnableAutoSavingState)
const setIsControlNet = useSetRecoilState(isControlNetState)
// Set Input Image
useEffect(() => {
@ -75,10 +78,17 @@ function App() {
setIsEnableAutoSavingState(isEnabled === 'true')
}
fetchData3()
const fetchData4 = async () => {
const isEnabled = await getIsControlNet().then(res => res.text())
setIsControlNet(isEnabled === 'true')
}
fetchData4()
}, [
setEnableFileManager,
setIsDisableModelSwitch,
setIsEnableAutoSavingState,
setIsControlNet,
])
// Dark Mode Hotkey

View File

@ -87,6 +87,12 @@ export default async function inpaint(
fd.append('p2pImageGuidanceScale', settings.p2pImageGuidanceScale.toString())
fd.append('p2pGuidanceScale', settings.p2pGuidanceScale.toString())
// ControlNet
fd.append(
'controlnet_conditioning_scale',
settings.controlnetConditioningScale.toString()
)
if (sizeLimit === undefined) {
fd.append('sizeLimit', '1080')
} else {
@ -116,6 +122,12 @@ export function getIsDisableModelSwitch() {
})
}
export function getIsControlNet() {
return fetch(`${API_ENDPOINT}/is_controlnet`, {
method: 'GET',
})
}
export function getEnableFileManager() {
return fetch(`${API_ENDPOINT}/is_enable_file_manager`, {
method: 'GET',
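The new getIsControlNet() helper queries a /is_controlnet endpoint. The matching server-side route is among this commit's 17 changed files but does not appear in the hunks shown here, so the sketch below is an illustration only, assuming the Flask app that lama_cleaner.server already uses and a module-level flag derived from the parsed CLI args (the variable name and wiring are assumptions, not the commit's code):

# Illustrative sketch only -- not the commit's server.py. The route path comes from
# the frontend fetch above; the flag variable and its wiring are assumed.
from flask import Flask

app = Flask(__name__)
is_controlnet_enabled = False  # would be set from the parsed args at startup


@app.route("/is_controlnet")
def is_controlnet():
    # The frontend compares the raw response body against the string 'true'.
    return "true" if is_controlnet_enabled else "false"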

View File

@ -3,6 +3,7 @@ import { useRecoilState, useRecoilValue } from 'recoil'
import * as PopoverPrimitive from '@radix-ui/react-popover'
import { useToggle } from 'react-use'
import {
isControlNetState,
isInpaintingState,
negativePropmtState,
propmtState,
@ -26,6 +27,7 @@ const SidePanel = () => {
useRecoilState(negativePropmtState)
const isInpainting = useRecoilValue(isInpaintingState)
const prompt = useRecoilValue(propmtState)
const isControlNet = useRecoilValue(isControlNetState)
const handleOnInput = (evt: FormEvent<HTMLTextAreaElement>) => {
evt.preventDefault()
@ -115,6 +117,22 @@ const SidePanel = () => {
}}
/>
{isControlNet && (
<NumberInputSetting
title="ControlNet Weight"
width={INPUT_WIDTH}
allowFloat
value={`${setting.controlnetConditioningScale}`}
desc="Lowered this value if there is a big misalignment between the text prompt and the control image"
onValue={value => {
const val = value.length === 0 ? 0 : parseFloat(value)
setSettingState(old => {
return { ...old, controlnetConditioningScale: val }
})
}}
/>
)}
<NumberInputSetting
title="Mask Blur"
width={INPUT_WIDTH}

View File

@ -51,6 +51,7 @@ interface AppState {
enableFileManager: boolean
gifImage: HTMLImageElement | undefined
brushSize: number
isControlNet: boolean
}
export const appState = atom<AppState>({
@ -70,6 +71,7 @@ export const appState = atom<AppState>({
enableFileManager: false,
gifImage: undefined,
brushSize: 40,
isControlNet: false,
},
})
@ -239,6 +241,18 @@ export const isDisableModelSwitchState = selector({
},
})
export const isControlNetState = selector({
key: 'isControlNetState',
get: ({ get }) => {
const app = get(appState)
return app.isControlNet
},
set: ({ get, set }, newValue: any) => {
const app = get(appState)
set(appState, { ...app, isControlNet: newValue })
},
})
export const isEnableAutoSavingState = selector({
key: 'isEnableAutoSavingState',
get: ({ get }) => {
@ -379,6 +393,9 @@ export interface Settings {
p2pSteps: number
p2pImageGuidanceScale: number
p2pGuidanceScale: number
// ControlNet
controlnetConditioningScale: number
}
const defaultHDSettings: ModelsHDSettings = {
@ -482,6 +499,7 @@ export enum SDSampler {
kEuler = 'k_euler',
kEulerA = 'k_euler_a',
dpmPlusPlus = 'dpm++',
uni_pc = 'uni_pc',
}
export enum SDMode {
@ -510,7 +528,7 @@ export const settingStateDefault: Settings = {
sdStrength: 0.75,
sdSteps: 50,
sdGuidanceScale: 7.5,
sdSampler: SDSampler.uni_pc,
sdSeed: 42,
sdSeedFixed: false,
sdNumSamples: 1,
@ -533,6 +551,9 @@ export const settingStateDefault: Settings = {
p2pSteps: 50,
p2pImageGuidanceScale: 1.5,
p2pGuidanceScale: 7.5,
// ControlNet
controlnetConditioningScale: 0.4,
}
const localStorageEffect =

View File

@ -6,7 +6,8 @@ MPS_SUPPORT_MODELS = [
"anything4", "anything4",
"realisticVision1.4", "realisticVision1.4",
"sd2", "sd2",
"paint_by_example" "paint_by_example",
"controlnet",
] ]
DEFAULT_MODEL = "lama" DEFAULT_MODEL = "lama"
@ -24,10 +25,12 @@ AVAILABLE_MODELS = [
"sd2", "sd2",
"paint_by_example", "paint_by_example",
"instruct_pix2pix", "instruct_pix2pix",
"controlnet",
] ]
SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"]
AVAILABLE_DEVICES = ["cuda", "cpu", "mps"] AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
DEFAULT_DEVICE = 'cuda' DEFAULT_DEVICE = "cuda"
NO_HALF_HELP = """ NO_HALF_HELP = """
Using full precision model. Using full precision model.
@ -46,6 +49,10 @@ SD_CPU_TEXTENCODER_HELP = """
Run Stable Diffusion text encoder model on CPU to save GPU memory.
"""
SD_CONTROLNET_HELP = """
Run Stable Diffusion 1.5 inpainting model with controlNet-canny model.
"""
LOCAL_FILES_ONLY_HELP = """
Use local files only, not connect to Hugging Face server. (sd/paint_by_example)
"""
@ -55,8 +62,7 @@ Enable xFormers optimizations. Requires xformers package has been installed. See
""" """
DEFAULT_MODEL_DIR = os.getenv( DEFAULT_MODEL_DIR = os.getenv(
"XDG_CACHE_HOME", "XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")
os.path.join(os.path.expanduser("~"), ".cache")
) )
MODEL_DIR_HELP = """ MODEL_DIR_HELP = """
Model download directory (by setting XDG_CACHE_HOME environment variable), by default model downloaded to ~/.cache Model download directory (by setting XDG_CACHE_HOME environment variable), by default model downloaded to ~/.cache
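SD_CONTROLNET_HELP above presumably backs a new command-line switch registered in lama_cleaner/parse_args.py, whose diff is not included in the hunks shown here. A minimal sketch of how such a switch could be wired with argparse follows; the flag name --sd-controlnet is an assumption, not taken from this commit:

# Hedged sketch: the flag name is assumed; the real option lives in
# lama_cleaner/parse_args.py, which is not shown above.
import argparse

SD_CONTROLNET_HELP = """
Run Stable Diffusion 1.5 inpainting model with controlNet-canny model.
"""

parser = argparse.ArgumentParser()
parser.add_argument("--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP)
args = parser.parse_args(["--sd-controlnet"])
print(args.sd_controlnet)  # True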

View File

@ -0,0 +1,157 @@
import PIL.Image
import cv2
import numpy as np
import torch
from diffusers import (
ControlNetModel,
)
from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import torch_gc, get_scheduler
from lama_cleaner.schema import Config
class CPUTextEncoderWrapper:
def __init__(self, text_encoder, torch_dtype):
self.config = text_encoder.config
self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
self.torch_dtype = torch_dtype
del text_encoder
torch_gc()
def __call__(self, x, **kwargs):
input_device = x.device
return [
self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0]
.to(input_device)
.to(self.torch_dtype)
]
@property
def dtype(self):
return self.torch_dtype
NAMES_MAP = {
"sd1.5": "runwayml/stable-diffusion-inpainting",
"anything4": "Sanster/anything-4.0-inpainting",
"realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting",
}
class ControlNet(DiffusionInpaintModel):
name = "controlnet"
pad_mod = 8
min_size = 512
def init_model(self, device: torch.device, **kwargs):
from .pipeline import StableDiffusionControlNetInpaintPipeline
model_id = NAMES_MAP[kwargs["name"]]
fp16 = not kwargs.get("no_half", False)
model_kwargs = {
"local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"])
}
if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
logger.info("Disable Stable Diffusion Model NSFW checker")
model_kwargs.update(
dict(
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
)
use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
controlnet = ControlNetModel.from_pretrained(
f"lllyasviel/sd-controlnet-canny", torch_dtype=torch_dtype
)
self.model = StableDiffusionControlNetInpaintPipeline.from_pretrained(
model_id,
controlnet=controlnet,
revision="fp16" if use_gpu and fp16 else "main",
torch_dtype=torch_dtype,
**model_kwargs,
)
# https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
self.model.enable_attention_slicing()
# https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
if kwargs.get("enable_xformers", False):
self.model.enable_xformers_memory_efficient_attention()
if kwargs.get("cpu_offload", False) and use_gpu:
logger.info("Enable sequential cpu offload")
self.model.enable_sequential_cpu_offload(gpu_id=0)
else:
self.model = self.model.to(device)
if kwargs["sd_cpu_textencoder"]:
logger.info("Run Stable Diffusion TextEncoder on CPU")
self.model.text_encoder = CPUTextEncoderWrapper(
self.model.text_encoder, torch_dtype
)
self.callback = kwargs.pop("callback", None)
def forward(self, image, mask, config: Config):
"""Input image and output image have same size
image: [H, W, C] RGB
mask: [H, W, 1] 255 means area to repaint
return: BGR IMAGE
"""
scheduler_config = self.model.scheduler.config
scheduler = get_scheduler(config.sd_sampler, scheduler_config)
self.model.scheduler = scheduler
if config.sd_mask_blur != 0:
k = 2 * config.sd_mask_blur + 1
mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
img_h, img_w = image.shape[:2]
canny_image = cv2.Canny(image, 100, 200)
canny_image = canny_image[:, :, None]
canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
canny_image = PIL.Image.fromarray(canny_image)
mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
image = PIL.Image.fromarray(image)
output = self.model(
image=image,
control_image=canny_image,
prompt=config.prompt,
negative_prompt=config.negative_prompt,
mask_image=mask_image,
num_inference_steps=config.sd_steps,
guidance_scale=config.sd_guidance_scale,
output_type="np.array",
callback=self.callback,
height=img_h,
width=img_w,
generator=torch.manual_seed(config.sd_seed),
controlnet_conditioning_scale=config.controlnet_conditioning_scale,
).images[0]
output = (output * 255).round().astype("uint8")
output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
return output
def forward_post_process(self, result, image, mask, config):
if config.sd_match_histograms:
result = self._match_histograms(result, image[:, :, ::-1], mask)
if config.sd_mask_blur != 0:
k = 2 * config.sd_mask_blur + 1
mask = cv2.GaussianBlur(mask, (k, k), 0)
return result, image, mask
@staticmethod
def is_downloaded() -> bool:
# model will be downloaded when the app starts, and can't be switched in frontend settings
return True
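For reference, the control image that ControlNet.forward() builds can be reproduced in isolation: a Canny edge map is computed from the RGB input, stacked to three channels, and wrapped as a PIL image before being handed to the pipeline as control_image. The snippet below simply restates that preprocessing with a stand-in input array; the 100/200 thresholds mirror the values used in forward():

import cv2
import numpy as np
import PIL.Image

# Stand-in for the RGB image passed to forward(); any [H, W, 3] uint8 array works.
rgb = (np.random.rand(512, 512, 3) * 255).astype("uint8")

canny = cv2.Canny(rgb, 100, 200)                          # [H, W] uint8 edge map
canny = np.concatenate([canny[:, :, None]] * 3, axis=2)   # [H, W, 3], as in forward()
control_image = PIL.Image.fromarray(canny)                # passed as control_image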

View File

@ -0,0 +1,3 @@
from .pipeline_stable_diffusion_controlnet_inpaint import (
StableDiffusionControlNetInpaintPipeline,
)

View File

@ -0,0 +1,585 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copy from https://github.com/mikonvergence/ControlNetInpaint/blob/main/src/pipeline_stable_diffusion_controlnet_inpaint.py
import torch
import PIL.Image
import numpy as np
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import *
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> # !pip install opencv-python transformers accelerate
>>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
>>> from diffusers.utils import load_image
>>> import numpy as np
>>> import torch
>>> import cv2
>>> from PIL import Image
>>> # download an image
>>> image = load_image(
... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
... )
>>> image = np.array(image)
>>> mask_image = load_image(
... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
... )
>>> mask_image = np.array(mask_image)
>>> # get canny image
>>> canny_image = cv2.Canny(image, 100, 200)
>>> canny_image = canny_image[:, :, None]
>>> canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
>>> canny_image = Image.fromarray(canny_image)
>>> # load control net and stable diffusion v1-5
>>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
>>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
... "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
... )
>>> # speed up diffusion process with faster scheduler and memory optimization
>>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
>>> # remove following line if xformers is not installed
>>> pipe.enable_xformers_memory_efficient_attention()
>>> pipe.enable_model_cpu_offload()
>>> # generate image
>>> generator = torch.manual_seed(0)
>>> image = pipe(
... "futuristic-looking doggo",
... num_inference_steps=20,
... generator=generator,
... image=image,
... control_image=canny_image,
... mask_image=mask_image
... ).images[0]
```
"""
def prepare_mask_and_masked_image(image, mask):
"""
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
``image`` and ``1`` for the ``mask``.
The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
Args:
image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
Raises:
ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
(or the other way around).
Returns:
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
dimensions: ``batch x channels x height x width``.
"""
if isinstance(image, torch.Tensor):
if not isinstance(mask, torch.Tensor):
raise TypeError(
f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not"
)
# Batch single image
if image.ndim == 3:
assert (
image.shape[0] == 3
), "Image outside a batch should be of shape (3, H, W)"
image = image.unsqueeze(0)
# Batch and add channel dim for single mask
if mask.ndim == 2:
mask = mask.unsqueeze(0).unsqueeze(0)
# Batch single mask or add channel dim
if mask.ndim == 3:
# Single batched mask, no channel dim or single mask not batched but channel dim
if mask.shape[0] == 1:
mask = mask.unsqueeze(0)
# Batched masks no channel dim
else:
mask = mask.unsqueeze(1)
assert (
image.ndim == 4 and mask.ndim == 4
), "Image and Mask must have 4 dimensions"
assert (
image.shape[-2:] == mask.shape[-2:]
), "Image and Mask must have the same spatial dimensions"
assert (
image.shape[0] == mask.shape[0]
), "Image and Mask must have the same batch size"
# Check image is in [-1, 1]
if image.min() < -1 or image.max() > 1:
raise ValueError("Image should be in [-1, 1] range")
# Check mask is in [0, 1]
if mask.min() < 0 or mask.max() > 1:
raise ValueError("Mask should be in [0, 1] range")
# Binarize mask
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
# Image as float32
image = image.to(dtype=torch.float32)
elif isinstance(mask, torch.Tensor):
raise TypeError(
f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not"
)
else:
# preprocess image
if isinstance(image, (PIL.Image.Image, np.ndarray)):
image = [image]
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
image = [np.array(i.convert("RGB"))[None, :] for i in image]
image = np.concatenate(image, axis=0)
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
image = np.concatenate([i[None, :] for i in image], axis=0)
image = image.transpose(0, 3, 1, 2)
image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
# preprocess mask
if isinstance(mask, (PIL.Image.Image, np.ndarray)):
mask = [mask]
if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
mask = np.concatenate(
[np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
)
mask = mask.astype(np.float32) / 255.0
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = image * (mask < 0.5)
return mask, masked_image
class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline):
r"""
Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance.
This model inherits from [`StableDiffusionControlNetPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
controlnet ([`ControlNetModel`]):
Provides additional conditioning to the unet during the denoising process
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
[`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
safety_checker ([`StableDiffusionSafetyChecker`]):
Classification module that estimates whether generated images could be considered offensive or harmful.
Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
feature_extractor ([`CLIPFeatureExtractor`]):
Model that extracts features from generated images to be used as inputs for the `safety_checker`.
"""
def prepare_mask_latents(
self,
mask,
masked_image,
batch_size,
height,
width,
dtype,
device,
generator,
do_classifier_free_guidance,
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
)
mask = mask.to(device=device, dtype=dtype)
masked_image = masked_image.to(device=device, dtype=dtype)
# encode the mask image into latents space so we can concatenate it to the latents
if isinstance(generator, list):
masked_image_latents = [
self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(
generator=generator[i]
)
for i in range(batch_size)
]
masked_image_latents = torch.cat(masked_image_latents, dim=0)
else:
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(
generator=generator
)
masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(
batch_size // masked_image_latents.shape[0], 1, 1, 1
)
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
masked_image_latents = (
torch.cat([masked_image_latents] * 2)
if do_classifier_free_guidance
else masked_image_latents
)
# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[torch.FloatTensor, PIL.Image.Image] = None,
control_image: Union[
torch.FloatTensor,
PIL.Image.Image,
List[torch.FloatTensor],
List[PIL.Image.Image],
] = None,
mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: float = 1.0,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
image (`PIL.Image.Image`):
`Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
be masked out with `mask_image` and repainted according to `prompt`.
control_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
also be accepted as an image. The control image is automatically resized to fit the output image.
mask_image (`PIL.Image.Image`):
`Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
instead of 3, so the expected shape would be `(B, H, W, 1)`.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead.
Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original unet.
Examples:
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
# 0. Default height and width to unet
height, width = self._default_height_width(height, width, control_image)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
control_image,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
)
# 4. Prepare image
control_image = self.prepare_image(
control_image,
width,
height,
batch_size * num_images_per_prompt,
num_images_per_prompt,
device,
self.controlnet.dtype,
)
if do_classifier_free_guidance:
control_image = torch.cat([control_image] * 2)
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
num_channels_latents = self.controlnet.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# EXTRA: prepare mask latents
mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
mask, masked_image_latents = self.prepare_mask_latents(
mask,
masked_image,
batch_size * num_images_per_prompt,
height,
width,
prompt_embeds.dtype,
device,
generator,
do_classifier_free_guidance,
)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = (
torch.cat([latents] * 2) if do_classifier_free_guidance else latents
)
latent_model_input = self.scheduler.scale_model_input(
latent_model_input, t
)
down_block_res_samples, mid_block_res_sample = self.controlnet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
controlnet_cond=control_image,
return_dict=False,
)
down_block_res_samples = [
down_block_res_sample * controlnet_conditioning_scale
for down_block_res_sample in down_block_res_samples
]
mid_block_res_sample *= controlnet_conditioning_scale
# predict the noise residual
latent_model_input = torch.cat(
[latent_model_input, mask, masked_image_latents], dim=1
)
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(
noise_pred, t, latents, **extra_step_kwargs
).prev_sample
# call the callback, if provided
if i == len(timesteps) - 1 or (
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
torch.cuda.empty_cache()
if output_type == "latent":
image = latents
has_nsfw_concept = None
elif output_type == "pil":
# 8. Post-processing
image = self.decode_latents(latents)
# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(
image, device, prompt_embeds.dtype
)
# 10. Convert to PIL
image = self.numpy_to_pil(image)
else:
# 8. Post-processing
image = self.decode_latents(latents)
# 9. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(
image, device, prompt_embeds.dtype
)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(
images=image, nsfw_content_detected=has_nsfw_concept
)

View File

@ -1,22 +1,12 @@
import random
import PIL.Image
import cv2
import numpy as np
import torch
from diffusers import (
PNDMScheduler,
DDIMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
)
from loguru import logger
from lama_cleaner.model.base import DiffusionInpaintModel
from lama_cleaner.model.utils import torch_gc, get_scheduler
from lama_cleaner.schema import Config
class CPUTextEncoderWrapper:
@ -101,22 +91,7 @@ class SD(DiffusionInpaintModel):
""" """
scheduler_config = self.model.scheduler.config scheduler_config = self.model.scheduler.config
scheduler = get_scheduler(config.sd_sampler, scheduler_config)
if config.sd_sampler == SDSampler.ddim:
scheduler = DDIMScheduler.from_config(scheduler_config)
elif config.sd_sampler == SDSampler.pndm:
scheduler = PNDMScheduler.from_config(scheduler_config)
elif config.sd_sampler == SDSampler.k_lms:
scheduler = LMSDiscreteScheduler.from_config(scheduler_config)
elif config.sd_sampler == SDSampler.k_euler:
scheduler = EulerDiscreteScheduler.from_config(scheduler_config)
elif config.sd_sampler == SDSampler.k_euler_a:
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)
elif config.sd_sampler == SDSampler.dpm_plus_plus:
scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config)
else:
raise ValueError(config.sd_sampler)
self.model.scheduler = scheduler
if config.sd_mask_blur != 0:
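Both sd.py (above) and the new controlnet.py now call get_scheduler(config.sd_sampler, scheduler_config) from lama_cleaner.model.utils, but the body of that helper falls outside the hunks shown in this view. The sketch below reconstructs what it plausibly does, based on the if/elif chain it replaces and the scheduler imports added to utils.py further down; the uni_pc branch is inferred from the new SDSampler.uni_pc value and the UniPCMultistepScheduler import, so treat this as an approximation rather than the committed code:

# Approximate reconstruction of get_scheduler (actual definition not shown in these hunks).
from diffusers import (
    DDIMScheduler,
    PNDMScheduler,
    LMSDiscreteScheduler,
    EulerDiscreteScheduler,
    EulerAncestralDiscreteScheduler,
    DPMSolverMultistepScheduler,
    UniPCMultistepScheduler,
)
from lama_cleaner.schema import SDSampler


def get_scheduler(sd_sampler, scheduler_config):
    # Map the sampler chosen in the frontend to a diffusers scheduler instance.
    if sd_sampler == SDSampler.ddim:
        return DDIMScheduler.from_config(scheduler_config)
    if sd_sampler == SDSampler.pndm:
        return PNDMScheduler.from_config(scheduler_config)
    if sd_sampler == SDSampler.k_lms:
        return LMSDiscreteScheduler.from_config(scheduler_config)
    if sd_sampler == SDSampler.k_euler:
        return EulerDiscreteScheduler.from_config(scheduler_config)
    if sd_sampler == SDSampler.k_euler_a:
        return EulerAncestralDiscreteScheduler.from_config(scheduler_config)
    if sd_sampler == SDSampler.dpm_plus_plus:
        return DPMSolverMultistepScheduler.from_config(scheduler_config)
    if sd_sampler == SDSampler.uni_pc:
        return UniPCMultistepScheduler.from_config(scheduler_config)
    raise ValueError(sd_sampler)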

View File

@ -7,17 +7,35 @@ import numpy as np
import collections
from itertools import repeat
from diffusers import (
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
UniPCMultistepScheduler,
)
from lama_cleaner.schema import SDSampler
from torch import conv2d, conv_transpose2d
def make_beta_schedule(
device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
if schedule == "linear":
betas = (
torch.linspace(
linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
)
** 2
)
elif schedule == "cosine":
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
).to(device)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2).to(device)
alphas = alphas / alphas[0]
@ -25,9 +43,14 @@ def make_beta_schedule(device, schedule, n_timestep, linear_start=1e-4, linear_e
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear":
betas = torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64
)
elif schedule == "sqrt":
betas = (
torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
** 0.5
)
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
@ -39,33 +62,47 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
# according the the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt(
(1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
)
if verbose:
print(
f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
)
print(
f"For the chosen value of eta, which is {eta}, "
f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
)
return sigmas, alphas, alphas_prev
def make_ddim_timesteps(
ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
):
if ddim_discr_method == "uniform":
c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
elif ddim_discr_method == "quad":
ddim_timesteps = (
(np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
).astype(int)
else:
raise NotImplementedError(
f'There is no ddim discretization method called "{ddim_discr_method}"'
)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
steps_out = ddim_timesteps + 1
if verbose:
print(f"Selected timesteps for ddim sampler: {steps_out}")
return steps_out
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
shape[0], *((1,) * (len(shape) - 1))
)
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
@ -81,7 +118,9 @@ def timestep_embedding(device, timesteps, dim, max_period=10000, repeat_only=Fal
""" """
half = dim // 2 half = dim // 2
freqs = torch.exp( freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half -math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half
).to(device=device) ).to(device=device)
args = timesteps[:, None].float() * freqs[None] args = timesteps[:, None].float() * freqs[None]
@ -115,9 +154,8 @@ class EasyDict(dict):
del self[name]
def _bias_act_ref(x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None):
"""Slow reference implementation of `bias_act()` using standard TensorFlow ops."""
assert isinstance(x, torch.Tensor)
assert clamp is None or clamp >= 0
spec = activation_funcs[act]
@ -147,7 +185,9 @@ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=N
return x
def bias_act(
x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None, impl="ref"
):
r"""Fused bias and activation function.
Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
@ -178,8 +218,10 @@ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None,
Tensor of the same shape and datatype as `x`.
"""
assert isinstance(x, torch.Tensor)
assert impl in ["ref", "cuda"]
return _bias_act_ref(
x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp
)
def _get_filter_size(f):
@ -223,7 +265,14 @@ def _parse_padding(padding):
return padx0, padx1, pady0, pady1
def setup_filter(
f,
device=torch.device("cpu"),
normalize=True,
flip_filter=False,
gain=1,
separable=None,
):
r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
Args:
@ -255,7 +304,7 @@ def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=Fals
# Separable?
if separable is None:
separable = f.ndim == 1 and f.numel() >= 8
if f.ndim == 1 and not separable:
f = f.ger(f)
assert f.ndim == (1 if separable else 2)
@ -282,27 +331,82 @@ def _ntuple(n):
to_2tuple = _ntuple(2)
activation_funcs = {
"linear": EasyDict(
func=lambda x, **_: x,
def_alpha=0,
def_gain=1,
cuda_idx=1,
ref="",
has_2nd_grad=False,
),
"relu": EasyDict(
func=lambda x, **_: torch.nn.functional.relu(x),
def_alpha=0,
def_gain=np.sqrt(2),
cuda_idx=2,
ref="y",
has_2nd_grad=False,
),
"lrelu": EasyDict(
func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
def_alpha=0.2,
def_gain=np.sqrt(2),
cuda_idx=3,
ref="y",
has_2nd_grad=False,
),
"tanh": EasyDict(
func=lambda x, **_: torch.tanh(x),
def_alpha=0,
def_gain=1,
cuda_idx=4,
ref="y",
has_2nd_grad=True,
),
"sigmoid": EasyDict(
func=lambda x, **_: torch.sigmoid(x),
def_alpha=0,
def_gain=1,
cuda_idx=5,
ref="y",
has_2nd_grad=True,
),
"elu": EasyDict(
func=lambda x, **_: torch.nn.functional.elu(x),
def_alpha=0,
def_gain=1,
cuda_idx=6,
ref="y",
has_2nd_grad=True,
),
"selu": EasyDict(
func=lambda x, **_: torch.nn.functional.selu(x),
def_alpha=0,
def_gain=1,
cuda_idx=7,
ref="y",
has_2nd_grad=True,
),
"softplus": EasyDict(
func=lambda x, **_: torch.nn.functional.softplus(x),
def_alpha=0,
def_gain=1,
cuda_idx=8,
ref="y",
has_2nd_grad=True,
),
"swish": EasyDict(
func=lambda x, **_: torch.sigmoid(x) * x,
def_alpha=0,
def_gain=np.sqrt(2),
cuda_idx=9,
ref="x",
has_2nd_grad=True,
),
}
def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"):
r"""Pad, upsample, filter, and downsample a batch of 2D images.
Performs the following sequence of operations for each channel:
@ -344,12 +448,13 @@ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cu
"""
# assert isinstance(x, torch.Tensor)
# assert impl in ['ref', 'cuda']
return _upfirdn2d_ref(
x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain
)
def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
"""Slow reference implementation of `upfirdn2d()` using standard PyTorch ops."""
# Validate arguments.
assert isinstance(x, torch.Tensor) and x.ndim == 4
if f is None:
@ -372,8 +477,15 @@ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
# Pad or crop.
x = torch.nn.functional.pad(
x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]
)
x = x[
:,
:,
max(-pady0, 0) : x.shape[2] - max(-pady1, 0),
max(-padx0, 0) : x.shape[3] - max(-padx1, 0),
]
# Setup filter.
f = f * (gain ** (f.ndim / 2))
@ -394,7 +506,7 @@ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
return x
def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl="cuda"):
r"""Downsample a batch of 2D images using the given 2D FIR filter.
By default, the result is padded so that its shape is a fraction of the input.
@ -431,10 +543,12 @@ def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'
pady0 + (fh - downy + 1) // 2,
pady1 + (fh - downy) // 2,
]
return upfirdn2d(
x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl
)
def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl="cuda"):
r"""Upsample a batch of 2D images using the given 2D FIR filter.
By default, the result is padded so that its shape is a multiple of the input.
@ -471,7 +585,15 @@ def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
pady0 + (fh + upy - 1) // 2,
pady1 + (fh - upy) // 2,
]
return upfirdn2d(
x,
f,
up=up,
padding=p,
flip_filter=flip_filter,
gain=gain * upx * upy,
impl=impl,
)
class MinibatchStdLayer(torch.nn.Module):
@ -482,13 +604,17 @@ class MinibatchStdLayer(torch.nn.Module):
def forward(self, x):
N, C, H, W = x.shape
G = (
torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N))
if self.group_size is not None
else N
)
F = self.num_channels
c = C // F
y = x.reshape(
G, -1, F, c, H, W
)  # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
y = y - y.mean(dim=0)  # [GnFcHW] Subtract mean over group.
y = y.square().mean(dim=0)  # [nFcHW] Calc variance over group.
y = (y + 1e-8).sqrt()  # [nFcHW] Calc stddev over group.
@ -500,17 +626,24 @@ class MinibatchStdLayer(torch.nn.Module):
class FullyConnectedLayer(torch.nn.Module):
def __init__(
self,
in_features,  # Number of input features.
out_features,  # Number of output features.
bias=True,  # Apply additive bias before the activation function?
activation="linear",  # Activation function: 'relu', 'lrelu', etc.
lr_multiplier=1,  # Learning rate multiplier.
bias_init=0,  # Initial value for the additive bias.
):
super().__init__()
self.weight = torch.nn.Parameter(
torch.randn([out_features, in_features]) / lr_multiplier
)
self.bias = (
torch.nn.Parameter(torch.full([out_features], np.float32(bias_init)))
if bias
else None
)
self.activation = activation
self.weight_gain = lr_multiplier / np.sqrt(in_features)
@ -522,7 +655,7 @@ class FullyConnectedLayer(torch.nn.Module):
if b is not None and self.bias_gain != 1:
b = b * self.bias_gain
if self.activation == "linear" and b is not None:
# out = torch.addmm(b.unsqueeze(0), x, w.t())
x = x.matmul(w.t())
out = x + b.reshape([-1 if i == x.ndim - 1 else 1 for i in range(x.ndim)])
@ -532,22 +665,33 @@ class FullyConnectedLayer(torch.nn.Module):
return out
def _conv2d_wrapper(
x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True
):
"""Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations."""
out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
# Flip weight if requested.
if (
not flip_weight
):  # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
w = w.flip([2, 3])
# Workaround performance pitfall in cuDNN 8.0.5, triggered when using
# 1x1 kernel + memory_format=channels_last + less than 64 channels.
if (
kw == 1
and kh == 1
and stride == 1
and padding in [0, [0, 0], (0, 0)]
and not transpose
):
if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
if out_channels <= 4 and groups == 1:
in_shape = x.shape
x = w.squeeze(3).squeeze(2) @ x.reshape(
[in_shape[0], in_channels_per_group, -1]
)
x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
else:
x = x.to(memory_format=torch.contiguous_format)
@@ -560,7 +704,9 @@ def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_w

    return op(x, w, stride=stride, padding=padding, groups=groups)


def conv2d_resample(
    x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False
):
    r"""2D convolution with optional up/downsampling.

    Padding is performed only once at the beginning, not between the operations.
@@ -587,7 +733,9 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight

    # Validate arguments.
    assert isinstance(x, torch.Tensor) and (x.ndim == 4)
    assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
    assert f is None or (
        isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32
    )
    assert isinstance(up, int) and (up >= 1)
    assert isinstance(down, int) and (down >= 1)
    # assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}"
@@ -610,20 +758,31 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight

    # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
    if kw == 1 and kh == 1 and (down > 1 and up == 1):
        x = upfirdn2d(
            x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter
        )
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        return x

    # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
    if kw == 1 and kh == 1 and (up > 1 and down == 1):
        x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
        x = upfirdn2d(
            x=x,
            f=f,
            up=up,
            padding=[px0, px1, py0, py1],
            gain=up**2,
            flip_filter=flip_filter,
        )
        return x

    # Fast path: downsampling only => use strided convolution.
    if down > 1 and up == 1:
        x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
        x = _conv2d_wrapper(
            x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight
        )
        return x

    # Fast path: upsampling with optional downsampling => use transpose strided convolution.
@@ -633,17 +792,31 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight

        else:
            w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
            w = w.transpose(1, 2)
            w = w.reshape(
                groups * in_channels_per_group, out_channels // groups, kh, kw
            )
        px0 -= kw - 1
        px1 -= kw - up
        py0 -= kh - 1
        py1 -= kh - up
        pxt = max(min(-px0, -px1), 0)
        pyt = max(min(-py0, -py1), 0)
        x = _conv2d_wrapper(
            x=x,
            w=w,
            stride=up,
            padding=[pyt, pxt],
            groups=groups,
            transpose=True,
            flip_weight=(not flip_weight),
        )
        x = upfirdn2d(
            x=x,
            f=f,
            padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt],
            gain=up**2,
            flip_filter=flip_filter,
        )
        if down > 1:
            x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
        return x
@@ -651,11 +824,19 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight

    # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
    if up == 1 and down == 1:
        if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
            return _conv2d_wrapper(
                x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight
            )

    # Fallback: Generic reference implementation.
    x = upfirdn2d(
        x=x,
        f=(f if up > 1 else None),
        up=up,
        padding=[px0, px1, py0, py1],
        gain=up**2,
        flip_filter=flip_filter,
    )
    x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
    if down > 1:
        x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
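
For orientation, a hedged sketch of calling conv2d_resample on its own. The tensor shapes and the 2x upsampling settings are illustrative assumptions; setup_filter is the helper defined earlier in this module:

import torch

# Illustrative: a 3x3 convolution fused with 2x upsampling, using the same
# [1, 3, 3, 1] low-pass filter that Conv2dLayer registers by default.
x = torch.randn(1, 8, 16, 16)                        # [N, C_in, H, W]
w = torch.randn(4, 8, 3, 3)                          # [C_out, C_in, kh, kw]
f = setup_filter([1, 3, 3, 1])
y = conv2d_resample(x=x, w=w, f=f, up=2, padding=1)
print(y.shape)                                       # (1, 4, 32, 32) with these settings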
@@ -663,15 +844,21 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight


class Conv2dLayer(torch.nn.Module):
    def __init__(
        self,
        in_channels,  # Number of input channels.
        out_channels,  # Number of output channels.
        kernel_size,  # Width and height of the convolution kernel.
        bias=True,  # Apply additive bias before the activation function?
        activation="linear",  # Activation function: 'relu', 'lrelu', etc.
        up=1,  # Integer upsampling factor.
        down=1,  # Integer downsampling factor.
        resample_filter=[
            1,
            3,
            3,
            1,
        ],  # Low-pass filter to apply when resampling activations.
        conv_clamp=None,  # Clamp the output to +-X, None = disable clamping.
        channels_last=False,  # Expect the input to have memory_format=channels_last?
        trainable=True,  # Update the weights of this layer during training?
@@ -680,33 +867,45 @@ class Conv2dLayer(torch.nn.Module):

        self.activation = activation
        self.up = up
        self.down = down
        self.register_buffer("resample_filter", setup_filter(resample_filter))
        self.conv_clamp = conv_clamp
        self.padding = kernel_size // 2
        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
        self.act_gain = activation_funcs[activation].def_gain

        memory_format = (
            torch.channels_last if channels_last else torch.contiguous_format
        )
        weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(
            memory_format=memory_format
        )
        bias = torch.zeros([out_channels]) if bias else None
        if trainable:
            self.weight = torch.nn.Parameter(weight)
            self.bias = torch.nn.Parameter(bias) if bias is not None else None
        else:
            self.register_buffer("weight", weight)
            if bias is not None:
                self.register_buffer("bias", bias)
            else:
                self.bias = None

    def forward(self, x, gain=1):
        w = self.weight * self.weight_gain
        x = conv2d_resample(
            x=x,
            w=w,
            f=self.resample_filter,
            up=self.up,
            down=self.down,
            padding=self.padding,
        )

        act_gain = self.act_gain * gain
        act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
        out = bias_act(
            x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp
        )
        return out
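
A small usage sketch of Conv2dLayer under assumed shapes; nothing below beyond the class itself comes from this commit:

import torch

# Illustrative: a stride-2 downsampling convolution with leaky ReLU activation.
layer = Conv2dLayer(
    in_channels=64, out_channels=128, kernel_size=3, activation="lrelu", down=2
)
x = torch.randn(2, 64, 32, 32)
y = layer(x)
print(y.shape)  # expected (2, 128, 16, 16) with these settings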
@@ -721,3 +920,22 @@ def set_seed(seed: int):

    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_scheduler(sd_sampler, scheduler_config):
    if sd_sampler == SDSampler.ddim:
        return DDIMScheduler.from_config(scheduler_config)
    elif sd_sampler == SDSampler.pndm:
        return PNDMScheduler.from_config(scheduler_config)
    elif sd_sampler == SDSampler.k_lms:
        return LMSDiscreteScheduler.from_config(scheduler_config)
    elif sd_sampler == SDSampler.k_euler:
        return EulerDiscreteScheduler.from_config(scheduler_config)
    elif sd_sampler == SDSampler.k_euler_a:
        return EulerAncestralDiscreteScheduler.from_config(scheduler_config)
    elif sd_sampler == SDSampler.dpm_plus_plus:
        return DPMSolverMultistepScheduler.from_config(scheduler_config)
    elif sd_sampler == SDSampler.uni_pc:
        return UniPCMultistepScheduler.from_config(scheduler_config)
    else:
        raise ValueError(sd_sampler)
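
get_scheduler maps a sampler name from the request onto a diffusers scheduler class. A hedged sketch of wiring it into a pipeline; the pipeline class and model id are placeholders, not part of this commit, and the new uni_pc option is presumably why requirements.txt below bumps diffusers to 0.14.0:

from diffusers import StableDiffusionInpaintPipeline
from lama_cleaner.schema import SDSampler  # assumes the enum lives in lama_cleaner.schema

# Hypothetical: swap the scheduler on an already-loaded pipeline, reusing its config.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting"  # placeholder model id
)
pipe.scheduler = get_scheduler(SDSampler.uni_pc, pipe.scheduler.config)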

View File

@@ -1,7 +1,9 @@

import torch
import gc

from lama_cleaner.const import SD15_MODELS
from lama_cleaner.helper import switch_mps_device
from lama_cleaner.model.controlnet import ControlNet
from lama_cleaner.model.fcf import FcF
from lama_cleaner.model.lama import LaMa
from lama_cleaner.model.ldm import LDM

@@ -20,7 +22,7 @@ models = {

    "zits": ZITS,
    "mat": MAT,
    "fcf": FcF,
    SD15.name: SD15,
    Anything4.name: Anything4,
    RealisticVision14.name: RealisticVision14,
    "cv2": OpenCV2,

@@ -39,6 +41,9 @@ class ModelManager:

        self.model = self.init_model(name, device, **kwargs)

    def init_model(self, name: str, device, **kwargs):
        if name in SD15_MODELS and kwargs.get("sd_controlnet", False):
            return ControlNet(device, **{**kwargs, "name": name})

        if name in models:
            model = models[name](device, **kwargs)
        else:
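
With this branch, any SD 1.5-family model is wrapped by ControlNet whenever sd_controlnet is passed through. A minimal sketch from the caller's side; the extra keyword arguments the real server passes (no_half, hf_access_token, and so on) are assumed optional here:

# Hypothetical: an SD 1.5 model name plus sd_controlnet=True should yield the
# ControlNet wrapper rather than the plain SD15 pipeline.
manager = ModelManager(name="sd1.5", device="cpu", sd_controlnet=True)
print(type(manager.model).__name__)  # expected "ControlNet" under these assumptions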

View File

@@ -5,25 +5,7 @@ from pathlib import Path

from loguru import logger

from lama_cleaner.const import *
from lama_cleaner.runtime import dump_environment_info

@@ -55,6 +37,9 @@ def parse_args():

    parser.add_argument(
        "--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP
    )
    parser.add_argument(
        "--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP
    )
    parser.add_argument(
        "--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP
    )

@@ -133,6 +118,10 @@ def parse_args():

            "torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation"
        )

    if args.sd_controlnet:
        if args.model not in SD15_MODELS:
            logger.warning(f"--sd_controlnet only support {SD15_MODELS}")

    if args.model_dir and args.model_dir is not None:
        if os.path.isfile(args.model_dir):
            parser.error(f"invalid --model-dir: {args.model_dir} is a file")
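
--sd-controlnet is an ordinary store_true switch: it defaults to False and flips to True only when present. A standalone illustration of that behaviour on a throwaway parser (it does not call the project's parse_args, and the help string is a stand-in):

import argparse

# Illustration only; SD_CONTROLNET_HELP stands in for the real help text.
p = argparse.ArgumentParser()
p.add_argument("--sd-controlnet", action="store_true", help="SD_CONTROLNET_HELP")
print(p.parse_args([]).sd_controlnet)                   # False
print(p.parse_args(["--sd-controlnet"]).sd_controlnet)  # True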

View File

@@ -24,9 +24,10 @@ class SDSampler(str, Enum):

    ddim = "ddim"
    pndm = "pndm"
    k_lms = "k_lms"
    k_euler = "k_euler"
    k_euler_a = "k_euler_a"
    dpm_plus_plus = "dpm++"
    uni_pc = "uni_pc"


class Config(BaseModel):

@@ -71,14 +72,14 @@ class Config(BaseModel):

    # Higher guidance scale encourages to generate images that are closely linked
    # to the text prompt, usually at the expense of lower image quality.
    sd_guidance_scale: float = 7.5
    sd_sampler: str = SDSampler.uni_pc
    # -1 mean random seed
    sd_seed: int = 42
    sd_match_histograms: bool = False

    # Configs for opencv inpainting
    # opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
    cv2_flag: str = "INPAINT_NS"
    cv2_radius: int = 4

    # Paint by Example

@@ -93,3 +94,6 @@ class Config(BaseModel):

    p2p_steps: int = 50
    p2p_image_guidance_scale: float = 1.5
    p2p_guidance_scale: float = 7.5

    # ControlNet
    controlnet_conditioning_scale: float = 0.4
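
controlnet_conditioning_scale joins the request Config with a default of 0.4; higher values make the result follow the ControlNet conditioning image more strictly. A hedged construction sketch, assuming the Config fields not shown in this hunk also carry defaults:

# Hypothetical request config: keep everything at its default except the
# ControlNet strength and the sampler.
config = Config(sd_sampler=SDSampler.uni_pc, controlnet_conditioning_scale=0.6)
print(config.controlnet_conditioning_scale)  # 0.6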

View File

@@ -18,6 +18,7 @@ import numpy as np

from loguru import logger
from watchdog.events import FileSystemEventHandler

from lama_cleaner.const import SD15_MODELS
from lama_cleaner.interactive_seg import InteractiveSeg, Click
from lama_cleaner.make_gif import make_compare_gif
from lama_cleaner.model_manager import ModelManager

@@ -88,6 +89,7 @@ interactive_seg_model: InteractiveSeg = None

device = None
input_image_path: str = None
is_disable_model_switch: bool = False
is_controlnet: bool = False
is_enable_file_manager: bool = False
is_enable_auto_saving: bool = False
is_desktop: bool = False

@@ -257,6 +259,7 @@ def process():

        p2p_steps=form["p2pSteps"],
        p2p_image_guidance_scale=form["p2pImageGuidanceScale"],
        p2p_guidance_scale=form["p2pGuidanceScale"],
        controlnet_conditioning_scale=form["controlnet_conditioning_scale"],
    )

    if config.sd_seed == -1:

@@ -347,6 +350,12 @@ def current_model():

    return model.name, 200


@app.route("/is_controlnet")
def get_is_controlnet():
    res = "true" if is_controlnet else "false"
    return res, 200


@app.route("/is_disable_model_switch")
def get_is_disable_model_switch():
    res = "true" if is_disable_model_switch else "false"

@@ -427,6 +436,9 @@ def main(args):

    global thumb
    global output_dir
    global is_enable_auto_saving
    global is_controlnet
    if args.sd_controlnet and args.model in SD15_MODELS:
        is_controlnet = True

    output_dir = args.output_dir
    if output_dir is not None:

@@ -464,6 +476,7 @@ def main(args):

    model = ModelManager(
        name=args.model,
        sd_controlnet=args.sd_controlnet,
        device=device,
        no_half=args.no_half,
        hf_access_token=args.hf_access_token,
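
The new /is_controlnet route lets a client discover whether the server was started with --sd-controlnet and an SD 1.5 model. A hedged sketch of querying it against the default host and port used elsewhere in this change:

import requests

# Hypothetical check against a locally running lama-cleaner server.
resp = requests.get("http://127.0.0.1:8080/is_controlnet")
print(resp.text)  # "true" only when ControlNet inpainting is active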

View File

@@ -6,25 +6,7 @@ import gradio as gr

from loguru import logger
from pydantic import BaseModel

from lama_cleaner.const import *

_config_file = None

@@ -33,6 +15,7 @@ class Config(BaseModel):

    host: str = "127.0.0.1"
    port: int = 8080
    model: str = DEFAULT_MODEL
    sd_controlnet: bool = False
    device: str = DEFAULT_DEVICE
    gui: bool = False
    no_gui_auto_close: bool = False

@@ -59,6 +42,7 @@ def save_config(

    host,
    port,
    model,
    sd_controlnet,
    device,
    gui,
    no_gui_auto_close,

@@ -127,6 +111,9 @@ def main(config_file: str):

            disable_nsfw = gr.Checkbox(
                init_config.disable_nsfw, label=f"{DISABLE_NSFW_HELP}"
            )
            sd_controlnet = gr.Checkbox(
                init_config.sd_controlnet, label=f"{SD_CONTROLNET_HELP}"
            )
            sd_cpu_textencoder = gr.Checkbox(
                init_config.sd_cpu_textencoder, label=f"{SD_CPU_TEXTENCODER_HELP}"
            )

@@ -149,6 +136,7 @@ def main(config_file: str):

            host,
            port,
            model,
            sd_controlnet,
            device,
            gui,
            no_gui_auto_close,

View File

@@ -12,7 +12,7 @@ pytest

yacs
markupsafe==2.0.1
scikit-image==0.19.3
diffusers[torch]==0.14.0
transformers>=4.25.1
watchdog==2.2.1
gradio