add controlnet inpainting

commit 5f4c62ac18 (parent 61928c9861)
@@ -1,11 +1,18 @@
+import os
+
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
 import warnings
 
 warnings.simplefilter("ignore", UserWarning)
 
 from lama_cleaner.parse_args import parse_args
 
 
 def entry_point():
     args = parse_args()
     # To make os.environ["XDG_CACHE_HOME"] = args.model_cache_dir works for diffusers
     # https://github.com/huggingface/diffusers/blob/be99201a567c1ccd841dc16fb24e88f7f239c187/src/diffusers/utils/constants.py#L18
     from lama_cleaner.server import main
 
     main(args)
@@ -7,6 +7,7 @@ import Workspace from './components/Workspace'
 import {
   enableFileManagerState,
   fileState,
+  isControlNetState,
   isDisableModelSwitchState,
   isEnableAutoSavingState,
   toastState,
@@ -17,6 +18,7 @@ import useHotKey from './hooks/useHotkey'
 import {
   getEnableAutoSaving,
   getEnableFileManager,
+  getIsControlNet,
   getIsDisableModelSwitch,
   isDesktop,
 } from './adapters/inpainting'
@@ -37,6 +39,7 @@ function App() {
   const setIsDisableModelSwitch = useSetRecoilState(isDisableModelSwitchState)
   const setEnableFileManager = useSetRecoilState(enableFileManagerState)
   const setIsEnableAutoSavingState = useSetRecoilState(isEnableAutoSavingState)
+  const setIsControlNet = useSetRecoilState(isControlNetState)
 
   // Set Input Image
   useEffect(() => {
@@ -75,10 +78,17 @@ function App() {
       setIsEnableAutoSavingState(isEnabled === 'true')
     }
     fetchData3()
+
+    const fetchData4 = async () => {
+      const isEnabled = await getIsControlNet().then(res => res.text())
+      setIsControlNet(isEnabled === 'true')
+    }
+    fetchData4()
   }, [
     setEnableFileManager,
     setIsDisableModelSwitch,
     setIsEnableAutoSavingState,
+    setIsControlNet,
   ])
 
   // Dark Mode Hotkey
@@ -87,6 +87,12 @@ export default async function inpaint(
  fd.append('p2pImageGuidanceScale', settings.p2pImageGuidanceScale.toString())
  fd.append('p2pGuidanceScale', settings.p2pGuidanceScale.toString())
 
+  // ControlNet
+  fd.append(
+    'controlnet_conditioning_scale',
+    settings.controlnetConditioningScale.toString()
+  )
+
  if (sizeLimit === undefined) {
    fd.append('sizeLimit', '1080')
  } else {
@@ -116,6 +122,12 @@ export function getIsDisableModelSwitch() {
   })
 }
 
+export function getIsControlNet() {
+  return fetch(`${API_ENDPOINT}/is_controlnet`, {
+    method: 'GET',
+  })
+}
+
 export function getEnableFileManager() {
   return fetch(`${API_ENDPOINT}/is_enable_file_manager`, {
     method: 'GET',
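
The server side of this endpoint is not part of this excerpt. For context, the frontend expects GET /is_controlnet to answer with the literal text "true" or "false" (App.tsx above compares the result of res.text() against 'true'). A minimal sketch of what such a route could look like, assuming a Flask server and a module-level flag derived from the parsed CLI args (both names are assumptions, not taken from this diff):

    # Sketch only: the real route presumably lives in lama_cleaner/server.py,
    # which this excerpt does not show. `app` and `IS_CONTROLNET` are assumed names.
    from flask import Flask

    app = Flask(__name__)
    IS_CONTROLNET = False  # assumed to be set from the parsed CLI args at startup

    @app.route("/is_controlnet")
    def get_is_controlnet():
        # Plain-text body, matching `res.text()` and the `=== 'true'` check in App.tsx.
        return "true" if IS_CONTROLNET else "false", 200
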
@@ -3,6 +3,7 @@ import { useRecoilState, useRecoilValue } from 'recoil'
 import * as PopoverPrimitive from '@radix-ui/react-popover'
 import { useToggle } from 'react-use'
 import {
+  isControlNetState,
   isInpaintingState,
   negativePropmtState,
   propmtState,
@@ -26,6 +27,7 @@ const SidePanel = () => {
     useRecoilState(negativePropmtState)
   const isInpainting = useRecoilValue(isInpaintingState)
   const prompt = useRecoilValue(propmtState)
+  const isControlNet = useRecoilValue(isControlNetState)
 
   const handleOnInput = (evt: FormEvent<HTMLTextAreaElement>) => {
     evt.preventDefault()
@@ -115,6 +117,22 @@ const SidePanel = () => {
          }}
        />

+        {isControlNet && (
+          <NumberInputSetting
+            title="ControlNet Weight"
+            width={INPUT_WIDTH}
+            allowFloat
+            value={`${setting.controlnetConditioningScale}`}
+            desc="Lower this value if there is a big misalignment between the text prompt and the control image"
+            onValue={value => {
+              const val = value.length === 0 ? 0 : parseFloat(value)
+              setSettingState(old => {
+                return { ...old, controlnetConditioningScale: val }
+              })
+            }}
+          />
+        )}
+
        <NumberInputSetting
          title="Mask Blur"
          width={INPUT_WIDTH}
@@ -51,6 +51,7 @@ interface AppState {
   enableFileManager: boolean
   gifImage: HTMLImageElement | undefined
   brushSize: number
+  isControlNet: boolean
 }
 
 export const appState = atom<AppState>({
@@ -70,6 +71,7 @@ export const appState = atom<AppState>({
     enableFileManager: false,
     gifImage: undefined,
     brushSize: 40,
+    isControlNet: false,
   },
 })
 
@@ -239,6 +241,18 @@ export const isDisableModelSwitchState = selector({
   },
 })
 
+export const isControlNetState = selector({
+  key: 'isControlNetState',
+  get: ({ get }) => {
+    const app = get(appState)
+    return app.isControlNet
+  },
+  set: ({ get, set }, newValue: any) => {
+    const app = get(appState)
+    set(appState, { ...app, isControlNet: newValue })
+  },
+})
+
 export const isEnableAutoSavingState = selector({
   key: 'isEnableAutoSavingState',
   get: ({ get }) => {
@@ -379,6 +393,9 @@ export interface Settings {
   p2pSteps: number
   p2pImageGuidanceScale: number
   p2pGuidanceScale: number
+
+  // ControlNet
+  controlnetConditioningScale: number
 }
 
 const defaultHDSettings: ModelsHDSettings = {
@@ -482,6 +499,7 @@ export enum SDSampler {
   kEuler = 'k_euler',
   kEulerA = 'k_euler_a',
   dpmPlusPlus = 'dpm++',
+  uni_pc = 'uni_pc',
 }
 
 export enum SDMode {
@@ -510,7 +528,7 @@ export const settingStateDefault: Settings = {
   sdStrength: 0.75,
   sdSteps: 50,
   sdGuidanceScale: 7.5,
-  sdSampler: SDSampler.pndm,
+  sdSampler: SDSampler.uni_pc,
   sdSeed: 42,
   sdSeedFixed: false,
   sdNumSamples: 1,
@@ -533,6 +551,9 @@ export const settingStateDefault: Settings = {
   p2pSteps: 50,
   p2pImageGuidanceScale: 1.5,
   p2pGuidanceScale: 7.5,
+
+  // ControlNet
+  controlnetConditioningScale: 0.4,
 }
 
 const localStorageEffect =
@@ -6,7 +6,8 @@ MPS_SUPPORT_MODELS = [
     "anything4",
     "realisticVision1.4",
     "sd2",
-    "paint_by_example"
+    "paint_by_example",
+    "controlnet",
 ]
 
 DEFAULT_MODEL = "lama"
@@ -24,10 +25,12 @@ AVAILABLE_MODELS = [
     "sd2",
     "paint_by_example",
     "instruct_pix2pix",
+    "controlnet",
 ]
+SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"]
 
 AVAILABLE_DEVICES = ["cuda", "cpu", "mps"]
-DEFAULT_DEVICE = 'cuda'
+DEFAULT_DEVICE = "cuda"
 
 NO_HALF_HELP = """
 Using full precision model.
@@ -46,6 +49,10 @@ SD_CPU_TEXTENCODER_HELP = """
 Run Stable Diffusion text encoder model on CPU to save GPU memory.
 """
 
+SD_CONTROLNET_HELP = """
+Run the Stable Diffusion 1.5 inpainting model with the ControlNet canny model.
+"""
+
 LOCAL_FILES_ONLY_HELP = """
 Use local files only, not connect to Hugging Face server. (sd/paint_by_example)
 """
@@ -55,8 +62,7 @@ Enable xFormers optimizations. Requires xformers package has been installed. See
 """
 
 DEFAULT_MODEL_DIR = os.getenv(
-    "XDG_CACHE_HOME",
-    os.path.join(os.path.expanduser("~"), ".cache")
+    "XDG_CACHE_HOME", os.path.join(os.path.expanduser("~"), ".cache")
 )
 MODEL_DIR_HELP = """
 Model download directory (by setting XDG_CACHE_HOME environment variable), by default model downloaded to ~/.cache
lama_cleaner/model/controlnet.py  (new file, 157 lines)
@@ -0,0 +1,157 @@
+import PIL.Image
+import cv2
+import numpy as np
+import torch
+from diffusers import (
+    ControlNetModel,
+)
+from loguru import logger
+
+from lama_cleaner.model.base import DiffusionInpaintModel
+from lama_cleaner.model.utils import torch_gc, get_scheduler
+from lama_cleaner.schema import Config
+
+
+class CPUTextEncoderWrapper:
+    def __init__(self, text_encoder, torch_dtype):
+        self.config = text_encoder.config
+        self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
+        self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
+        self.torch_dtype = torch_dtype
+        del text_encoder
+        torch_gc()
+
+    def __call__(self, x, **kwargs):
+        input_device = x.device
+        return [
+            self.text_encoder(x.to(self.text_encoder.device), **kwargs)[0]
+            .to(input_device)
+            .to(self.torch_dtype)
+        ]
+
+    @property
+    def dtype(self):
+        return self.torch_dtype
+
+
+NAMES_MAP = {
+    "sd1.5": "runwayml/stable-diffusion-inpainting",
+    "anything4": "Sanster/anything-4.0-inpainting",
+    "realisticVision1.4": "Sanster/Realistic_Vision_V1.4-inpainting",
+}
+
+
+class ControlNet(DiffusionInpaintModel):
+    name = "controlnet"
+    pad_mod = 8
+    min_size = 512
+
+    def init_model(self, device: torch.device, **kwargs):
+        from .pipeline import StableDiffusionControlNetInpaintPipeline
+
+        model_id = NAMES_MAP[kwargs["name"]]
+        fp16 = not kwargs.get("no_half", False)
+
+        model_kwargs = {
+            "local_files_only": kwargs.get("local_files_only", kwargs["sd_run_local"])
+        }
+        if kwargs["disable_nsfw"] or kwargs.get("cpu_offload", False):
+            logger.info("Disable Stable Diffusion Model NSFW checker")
+            model_kwargs.update(
+                dict(
+                    safety_checker=None,
+                    feature_extractor=None,
+                    requires_safety_checker=False,
+                )
+            )
+
+        use_gpu = device == torch.device("cuda") and torch.cuda.is_available()
+        torch_dtype = torch.float16 if use_gpu and fp16 else torch.float32
+
+        controlnet = ControlNetModel.from_pretrained(
+            "lllyasviel/sd-controlnet-canny", torch_dtype=torch_dtype
+        )
+        self.model = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+            model_id,
+            controlnet=controlnet,
+            revision="fp16" if use_gpu and fp16 else "main",
+            torch_dtype=torch_dtype,
+            **model_kwargs,
+        )
+
+        # https://huggingface.co/docs/diffusers/v0.7.0/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionInpaintPipeline.enable_attention_slicing
+        self.model.enable_attention_slicing()
+        # https://huggingface.co/docs/diffusers/v0.7.0/en/optimization/fp16#memory-efficient-attention
+        if kwargs.get("enable_xformers", False):
+            self.model.enable_xformers_memory_efficient_attention()
+
+        if kwargs.get("cpu_offload", False) and use_gpu:
+            logger.info("Enable sequential cpu offload")
+            self.model.enable_sequential_cpu_offload(gpu_id=0)
+        else:
+            self.model = self.model.to(device)
+            if kwargs["sd_cpu_textencoder"]:
+                logger.info("Run Stable Diffusion TextEncoder on CPU")
+                self.model.text_encoder = CPUTextEncoderWrapper(
+                    self.model.text_encoder, torch_dtype
+                )
+
+        self.callback = kwargs.pop("callback", None)
+
+    def forward(self, image, mask, config: Config):
+        """Input image and output image have same size
+        image: [H, W, C] RGB
+        mask: [H, W, 1] 255 means area to repaint
+        return: BGR IMAGE
+        """
+        scheduler_config = self.model.scheduler.config
+        scheduler = get_scheduler(config.sd_sampler, scheduler_config)
+        self.model.scheduler = scheduler
+
+        if config.sd_mask_blur != 0:
+            k = 2 * config.sd_mask_blur + 1
+            mask = cv2.GaussianBlur(mask, (k, k), 0)[:, :, np.newaxis]
+
+        img_h, img_w = image.shape[:2]
+
+        canny_image = cv2.Canny(image, 100, 200)
+        canny_image = canny_image[:, :, None]
+        canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
+        canny_image = PIL.Image.fromarray(canny_image)
+        mask_image = PIL.Image.fromarray(mask[:, :, -1], mode="L")
+        image = PIL.Image.fromarray(image)
+
+        output = self.model(
+            image=image,
+            control_image=canny_image,
+            prompt=config.prompt,
+            negative_prompt=config.negative_prompt,
+            mask_image=mask_image,
+            num_inference_steps=config.sd_steps,
+            guidance_scale=config.sd_guidance_scale,
+            output_type="np.array",
+            callback=self.callback,
+            height=img_h,
+            width=img_w,
+            generator=torch.manual_seed(config.sd_seed),
+            controlnet_conditioning_scale=config.controlnet_conditioning_scale,
+        ).images[0]
+
+        output = (output * 255).round().astype("uint8")
+        output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
+        return output
+
+    def forward_post_process(self, result, image, mask, config):
+        if config.sd_match_histograms:
+            result = self._match_histograms(result, image[:, :, ::-1], mask)
+
+        if config.sd_mask_blur != 0:
+            k = 2 * config.sd_mask_blur + 1
+            mask = cv2.GaussianBlur(mask, (k, k), 0)
+        return result, image, mask
+
+    @staticmethod
+    def is_downloaded() -> bool:
+        # the model is downloaded when the app starts and can't be switched in frontend settings
+        return True
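
The core of ControlNet.forward above is the canny control-image preparation: the RGB input is edge-detected and stacked to three channels before being handed to the pipeline as control_image. A self-contained sketch of just that step (the helper name is ours; the 100/200 thresholds come from the diff):

    # Standalone sketch of the canny control-image preparation used in forward().
    import cv2
    import numpy as np
    import PIL.Image

    def make_canny_control_image(image_rgb: np.ndarray) -> PIL.Image.Image:
        # Edge detection with the same thresholds as the diff (100, 200).
        edges = cv2.Canny(image_rgb, 100, 200)           # [H, W] uint8 edge map
        edges = np.repeat(edges[:, :, None], 3, axis=2)  # replicate to 3 channels
        return PIL.Image.fromarray(edges)
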
lama_cleaner/model/pipeline/__init__.py  (new file, 3 lines)
@@ -0,0 +1,3 @@
+from .pipeline_stable_diffusion_controlnet_inpaint import (
+    StableDiffusionControlNetInpaintPipeline,
+)
lama_cleaner/model/pipeline/pipeline_stable_diffusion_controlnet_inpaint.py  (new file, 585 lines)
@@ -0,0 +1,585 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Copy from https://github.com/mikonvergence/ControlNetInpaint/blob/main/src/pipeline_stable_diffusion_controlnet_inpaint.py
+
+import torch
+import PIL.Image
+import numpy as np
+
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import *
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> # !pip install opencv-python transformers accelerate
+        >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
+        >>> from diffusers.utils import load_image
+        >>> import numpy as np
+        >>> import torch
+
+        >>> import cv2
+        >>> from PIL import Image
+        >>> # download an image
+        >>> image = load_image(
+        ...     "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+        ... )
+        >>> image = np.array(image)
+        >>> mask_image = load_image(
+        ...     "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+        ... )
+        >>> mask_image = np.array(mask_image)
+        >>> # get canny image
+        >>> canny_image = cv2.Canny(image, 100, 200)
+        >>> canny_image = canny_image[:, :, None]
+        >>> canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
+        >>> canny_image = Image.fromarray(canny_image)
+
+        >>> # load control net and stable diffusion v1-5
+        >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+        >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+        ...     "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
+        ... )
+
+        >>> # speed up diffusion process with faster scheduler and memory optimization
+        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+        >>> # remove following line if xformers is not installed
+        >>> pipe.enable_xformers_memory_efficient_attention()
+
+        >>> pipe.enable_model_cpu_offload()
+
+        >>> # generate image
+        >>> generator = torch.manual_seed(0)
+        >>> image = pipe(
+        ...     "futuristic-looking doggo",
+        ...     num_inference_steps=20,
+        ...     generator=generator,
+        ...     image=image,
+        ...     control_image=canny_image,
+        ...     mask_image=mask_image
+        ... ).images[0]
+        ```
+"""
+
+
+def prepare_mask_and_masked_image(image, mask):
+    """
+    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+    converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+    ``image`` and ``1`` for the ``mask``.
+    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+    binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+    Args:
+        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+        mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+    Raises:
+        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+        should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+            (or the other way around).
+    Returns:
+        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+            dimensions: ``batch x channels x height x width``.
+    """
+    if isinstance(image, torch.Tensor):
+        if not isinstance(mask, torch.Tensor):
+            raise TypeError(
+                f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not"
+            )
+
+        # Batch single image
+        if image.ndim == 3:
+            assert (
+                image.shape[0] == 3
+            ), "Image outside a batch should be of shape (3, H, W)"
+            image = image.unsqueeze(0)
+
+        # Batch and add channel dim for single mask
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0).unsqueeze(0)
+
+        # Batch single mask or add channel dim
+        if mask.ndim == 3:
+            # Single batched mask, no channel dim or single mask not batched but channel dim
+            if mask.shape[0] == 1:
+                mask = mask.unsqueeze(0)
+
+            # Batched masks no channel dim
+            else:
+                mask = mask.unsqueeze(1)
+
+        assert (
+            image.ndim == 4 and mask.ndim == 4
+        ), "Image and Mask must have 4 dimensions"
+        assert (
+            image.shape[-2:] == mask.shape[-2:]
+        ), "Image and Mask must have the same spatial dimensions"
+        assert (
+            image.shape[0] == mask.shape[0]
+        ), "Image and Mask must have the same batch size"
+
+        # Check image is in [-1, 1]
+        if image.min() < -1 or image.max() > 1:
+            raise ValueError("Image should be in [-1, 1] range")
+
+        # Check mask is in [0, 1]
+        if mask.min() < 0 or mask.max() > 1:
+            raise ValueError("Mask should be in [0, 1] range")
+
+        # Binarize mask
+        mask[mask < 0.5] = 0
+        mask[mask >= 0.5] = 1
+
+        # Image as float32
+        image = image.to(dtype=torch.float32)
+    elif isinstance(mask, torch.Tensor):
+        raise TypeError(
+            f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not"
+        )
+    else:
+        # preprocess image
+        if isinstance(image, (PIL.Image.Image, np.ndarray)):
+            image = [image]
+
+        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+            image = [np.array(i.convert("RGB"))[None, :] for i in image]
+            image = np.concatenate(image, axis=0)
+        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+            image = np.concatenate([i[None, :] for i in image], axis=0)
+
+        image = image.transpose(0, 3, 1, 2)
+        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+        # preprocess mask
+        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+            mask = [mask]
+
+        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+            mask = np.concatenate(
+                [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
+            )
+            mask = mask.astype(np.float32) / 255.0
+        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+        mask[mask < 0.5] = 0
+        mask[mask >= 0.5] = 1
+        mask = torch.from_numpy(mask)
+
+        masked_image = image * (mask < 0.5)
+
+    return mask, masked_image
+
+
+class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline):
+    r"""
+    Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance.
+
+    This model inherits from [`StableDiffusionControlNetPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        controlnet ([`ControlNetModel`]):
+            Provides additional conditioning to the unet during the denoising process
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+        feature_extractor ([`CLIPFeatureExtractor`]):
+            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+    """
+
+    def prepare_mask_latents(
+        self,
+        mask,
+        masked_image,
+        batch_size,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        do_classifier_free_guidance,
+    ):
+        # resize the mask to latents shape as we concatenate the mask to the latents
+        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+        # and half precision
+        mask = torch.nn.functional.interpolate(
+            mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
+        )
+        mask = mask.to(device=device, dtype=dtype)
+
+        masked_image = masked_image.to(device=device, dtype=dtype)
+
+        # encode the mask image into latents space so we can concatenate it to the latents
+        if isinstance(generator, list):
+            masked_image_latents = [
+                self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(
+                    generator=generator[i]
+                )
+                for i in range(batch_size)
+            ]
+            masked_image_latents = torch.cat(masked_image_latents, dim=0)
+        else:
+            masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(
+                generator=generator
+            )
+        masked_image_latents = self.vae.config.scaling_factor * masked_image_latents
+
+        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+        if mask.shape[0] < batch_size:
+            if not batch_size % mask.shape[0] == 0:
+                raise ValueError(
+                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+                    " of masks that you pass is divisible by the total requested batch size."
+                )
+            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+        if masked_image_latents.shape[0] < batch_size:
+            if not batch_size % masked_image_latents.shape[0] == 0:
+                raise ValueError(
+                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+                    " Make sure the number of images that you pass is divisible by the total requested batch size."
+                )
+            masked_image_latents = masked_image_latents.repeat(
+                batch_size // masked_image_latents.shape[0], 1, 1, 1
+            )
+
+        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+        masked_image_latents = (
+            torch.cat([masked_image_latents] * 2)
+            if do_classifier_free_guidance
+            else masked_image_latents
+        )
+
+        # aligning device to prevent device errors when concating it with the latent model input
+        masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
+        return mask, masked_image_latents
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        control_image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+        ] = None,
+        mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback_steps: int = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        controlnet_conditioning_scale: float = 1.0,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+                be masked out with `mask_image` and repainted according to `prompt`.
+            control_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
+                The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
+                the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+                also be accepted as an image. The control image is automatically resized to fit the output image.
+            mask_image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+                to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+                instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead.
+                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+                `self.processor` in
+                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+            controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
+                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+                to the residual in the original unet.
+        Examples:
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
+        """
+        # 0. Default height and width to unet
+        height, width = self._default_height_width(height, width, control_image)
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            control_image,
+            height,
+            width,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+        )
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        prompt_embeds = self._encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+        )
+
+        # 4. Prepare image
+        control_image = self.prepare_image(
+            control_image,
+            width,
+            height,
+            batch_size * num_images_per_prompt,
+            num_images_per_prompt,
+            device,
+            self.controlnet.dtype,
+        )
+
+        if do_classifier_free_guidance:
+            control_image = torch.cat([control_image] * 2)
+
+        # 5. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+
+        # 6. Prepare latent variables
+        num_channels_latents = self.controlnet.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # EXTRA: prepare mask latents
+        mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
+        mask, masked_image_latents = self.prepare_mask_latents(
+            mask,
+            masked_image,
+            batch_size * num_images_per_prompt,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            do_classifier_free_guidance,
+        )
+
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 8. Denoising loop
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = (
+                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                )
+                latent_model_input = self.scheduler.scale_model_input(
+                    latent_model_input, t
+                )
+
+                down_block_res_samples, mid_block_res_sample = self.controlnet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    controlnet_cond=control_image,
+                    return_dict=False,
+                )
+
+                down_block_res_samples = [
+                    down_block_res_sample * controlnet_conditioning_scale
+                    for down_block_res_sample in down_block_res_samples
+                ]
+                mid_block_res_sample *= controlnet_conditioning_scale
+
+                # predict the noise residual
+                latent_model_input = torch.cat(
+                    [latent_model_input, mask, masked_image_latents], dim=1
+                )
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    down_block_additional_residuals=down_block_res_samples,
+                    mid_block_additional_residual=mid_block_res_sample,
+                ).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (
+                        noise_pred_text - noise_pred_uncond
+                    )
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(
+                    noise_pred, t, latents, **extra_step_kwargs
+                ).prev_sample
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or (
+                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
+                ):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)
+
+        # If we do sequential model offloading, let's offload unet and controlnet
+        # manually for max memory savings
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.unet.to("cpu")
+            self.controlnet.to("cpu")
+            torch.cuda.empty_cache()
+
+        if output_type == "latent":
+            image = latents
+            has_nsfw_concept = None
+        elif output_type == "pil":
+            # 8. Post-processing
+            image = self.decode_latents(latents)
+
+            # 9. Run safety checker
+            image, has_nsfw_concept = self.run_safety_checker(
+                image, device, prompt_embeds.dtype
+            )
+
+            # 10. Convert to PIL
+            image = self.numpy_to_pil(image)
+        else:
+            # 8. Post-processing
+            image = self.decode_latents(latents)
+
+            # 9. Run safety checker
+            image, has_nsfw_concept = self.run_safety_checker(
+                image, device, prompt_embeds.dtype
+            )
+
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(
+            images=image, nsfw_content_detected=has_nsfw_concept
+        )
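
A quick way to sanity-check prepare_mask_and_masked_image's documented contract (image normalized to float32 in [-1, 1], mask binarized to {0, 1}, both returned as 4-D batch x channels x height x width tensors). The import path follows the pipeline/__init__.py added above; the toy arrays are ours:

    import numpy as np
    import PIL.Image
    from lama_cleaner.model.pipeline.pipeline_stable_diffusion_controlnet_inpaint import (
        prepare_mask_and_masked_image,
    )

    image = PIL.Image.fromarray(np.full((64, 64, 3), 128, dtype=np.uint8))
    mask_arr = np.zeros((64, 64), dtype=np.uint8)
    mask_arr[16:48, 16:48] = 255  # white square = region to repaint
    mask = PIL.Image.fromarray(mask_arr)

    mask_t, masked_image_t = prepare_mask_and_masked_image(image, mask)
    print(mask_t.shape, masked_image_t.shape)
    # torch.Size([1, 1, 64, 64]) torch.Size([1, 3, 64, 64])
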
@@ -1,22 +1,12 @@
-import random
-
 import PIL.Image
 import cv2
 import numpy as np
 import torch
-from diffusers import (
-    PNDMScheduler,
-    DDIMScheduler,
-    LMSDiscreteScheduler,
-    EulerDiscreteScheduler,
-    EulerAncestralDiscreteScheduler,
-    DPMSolverMultistepScheduler,
-)
 from loguru import logger
 
 from lama_cleaner.model.base import DiffusionInpaintModel
-from lama_cleaner.model.utils import torch_gc, set_seed
-from lama_cleaner.schema import Config, SDSampler
+from lama_cleaner.model.utils import torch_gc, get_scheduler
+from lama_cleaner.schema import Config
 
 
 class CPUTextEncoderWrapper:
@@ -101,22 +91,7 @@ class SD(DiffusionInpaintModel):
         """
 
         scheduler_config = self.model.scheduler.config
-
-        if config.sd_sampler == SDSampler.ddim:
-            scheduler = DDIMScheduler.from_config(scheduler_config)
-        elif config.sd_sampler == SDSampler.pndm:
-            scheduler = PNDMScheduler.from_config(scheduler_config)
-        elif config.sd_sampler == SDSampler.k_lms:
-            scheduler = LMSDiscreteScheduler.from_config(scheduler_config)
-        elif config.sd_sampler == SDSampler.k_euler:
-            scheduler = EulerDiscreteScheduler.from_config(scheduler_config)
-        elif config.sd_sampler == SDSampler.k_euler_a:
-            scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler_config)
-        elif config.sd_sampler == SDSampler.dpm_plus_plus:
-            scheduler = DPMSolverMultistepScheduler.from_config(scheduler_config)
-        else:
-            raise ValueError(config.sd_sampler)
-
+        scheduler = get_scheduler(config.sd_sampler, scheduler_config)
         self.model.scheduler = scheduler
 
         if config.sd_mask_blur != 0:
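
get_scheduler itself is added to lama_cleaner/model/utils.py, but its body is outside this excerpt. Given the if/elif chain removed here and the scheduler imports added to utils.py in the next hunk, a plausible sketch (assuming the backend SDSampler enum also gains a uni_pc member, mirroring the frontend enum change above):

    from diffusers import (
        DDIMScheduler,
        PNDMScheduler,
        LMSDiscreteScheduler,
        EulerDiscreteScheduler,
        EulerAncestralDiscreteScheduler,
        DPMSolverMultistepScheduler,
        UniPCMultistepScheduler,
    )
    from lama_cleaner.schema import SDSampler

    def get_scheduler(sd_sampler, scheduler_config):
        # Map each SDSampler value to the matching diffusers scheduler class.
        if sd_sampler == SDSampler.ddim:
            return DDIMScheduler.from_config(scheduler_config)
        elif sd_sampler == SDSampler.pndm:
            return PNDMScheduler.from_config(scheduler_config)
        elif sd_sampler == SDSampler.k_lms:
            return LMSDiscreteScheduler.from_config(scheduler_config)
        elif sd_sampler == SDSampler.k_euler:
            return EulerDiscreteScheduler.from_config(scheduler_config)
        elif sd_sampler == SDSampler.k_euler_a:
            return EulerAncestralDiscreteScheduler.from_config(scheduler_config)
        elif sd_sampler == SDSampler.dpm_plus_plus:
            return DPMSolverMultistepScheduler.from_config(scheduler_config)
        elif sd_sampler == SDSampler.uni_pc:
            return UniPCMultistepScheduler.from_config(scheduler_config)
        else:
            raise ValueError(sd_sampler)
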
@@ -7,17 +7,35 @@ import numpy as np
 import collections
 from itertools import repeat
 
+from diffusers import (
+    DDIMScheduler,
+    PNDMScheduler,
+    LMSDiscreteScheduler,
+    EulerDiscreteScheduler,
+    EulerAncestralDiscreteScheduler,
+    DPMSolverMultistepScheduler,
+    UniPCMultistepScheduler,
+)
+
+from lama_cleaner.schema import SDSampler
 from torch import conv2d, conv_transpose2d
 
 
-def make_beta_schedule(device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
+def make_beta_schedule(
+    device, schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
+):
     if schedule == "linear":
         betas = (
-            torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+            torch.linspace(
+                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
+            )
+            ** 2
         )
 
     elif schedule == "cosine":
-        timesteps = (torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s).to(device)
+        timesteps = (
+            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
+        ).to(device)
         alphas = timesteps / (1 + cosine_s) * np.pi / 2
         alphas = torch.cos(alphas).pow(2).to(device)
         alphas = alphas / alphas[0]
@@ -25,9 +43,14 @@ def make_beta_schedule(device, schedule, n_timestep, linear_start=1e-4, linear_e
         betas = np.clip(betas, a_min=0, a_max=0.999)
 
     elif schedule == "sqrt_linear":
-        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+        betas = torch.linspace(
+            linear_start, linear_end, n_timestep, dtype=torch.float64
+        )
     elif schedule == "sqrt":
-        betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
+        betas = (
+            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
+            ** 0.5
+        )
     else:
         raise ValueError(f"schedule '{schedule}' unknown.")
     return betas.numpy()
@@ -39,33 +62,47 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
     alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
 
     # according the the formula provided in https://arxiv.org/abs/2010.02502
-    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
+    sigmas = eta * np.sqrt(
+        (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
+    )
     if verbose:
-        print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
-        print(f'For the chosen value of eta, which is {eta}, '
-              f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
+        print(
+            f"Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}"
+        )
+        print(
+            f"For the chosen value of eta, which is {eta}, "
+            f"this results in the following sigma_t schedule for ddim sampler {sigmas}"
+        )
     return sigmas, alphas, alphas_prev
 
 
-def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
-    if ddim_discr_method == 'uniform':
+def make_ddim_timesteps(
+    ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
+):
+    if ddim_discr_method == "uniform":
         c = num_ddpm_timesteps // num_ddim_timesteps
         ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
-    elif ddim_discr_method == 'quad':
-        ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
+    elif ddim_discr_method == "quad":
+        ddim_timesteps = (
+            (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
+        ).astype(int)
     else:
-        raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
+        raise NotImplementedError(
+            f'There is no ddim discretization method called "{ddim_discr_method}"'
+        )
 
     # assert ddim_timesteps.shape[0] == num_ddim_timesteps
     # add one to get the final alpha values right (the ones from first scale to data during sampling)
     steps_out = ddim_timesteps + 1
     if verbose:
-        print(f'Selected timesteps for ddim sampler: {steps_out}')
+        print(f"Selected timesteps for ddim sampler: {steps_out}")
     return steps_out
 
 
 def noise_like(shape, device, repeat=False):
-    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(
+        shape[0], *((1,) * (len(shape) - 1))
+    )
     noise = lambda: torch.randn(shape, device=device)
     return repeat_noise() if repeat else noise()
 
@@ -81,7 +118,9 @@ def timestep_embedding(device, timesteps, dim, max_period=10000, repeat_only=Fal
     """
     half = dim // 2
     freqs = torch.exp(
-        -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+        -math.log(max_period)
+        * torch.arange(start=0, end=half, dtype=torch.float32)
+        / half
    ).to(device=device)
 
     args = timesteps[:, None].float() * freqs[None]
@@ -115,9 +154,8 @@ class EasyDict(dict):
         del self[name]
 
 
-def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
-    """Slow reference implementation of `bias_act()` using standard TensorFlow ops.
-    """
+def _bias_act_ref(x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None):
+    """Slow reference implementation of `bias_act()` using standard TensorFlow ops."""
     assert isinstance(x, torch.Tensor)
     assert clamp is None or clamp >= 0
     spec = activation_funcs[act]
@@ -147,7 +185,9 @@ def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=N
     return x
 
 
-def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='ref'):
+def bias_act(
+    x, b=None, dim=1, act="linear", alpha=None, gain=None, clamp=None, impl="ref"
+):
     r"""Fused bias and activation function.
 
     Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
@@ -178,8 +218,10 @@ def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None,
         Tensor of the same shape and datatype as `x`.
     """
     assert isinstance(x, torch.Tensor)
-    assert impl in ['ref', 'cuda']
-    return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
+    assert impl in ["ref", "cuda"]
+    return _bias_act_ref(
+        x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp
+    )
 
 
 def _get_filter_size(f):
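With impl="ref", bias_act above is a thin dispatch into _bias_act_ref: add b along dim, apply the activation selected from activation_funcs, scale by that entry's def_gain, and optionally clamp. A stand-alone sketch of the same arithmetic for act="lrelu" (bias kept at zero; alpha and gain are the table's defaults):

import numpy as np
import torch

x = torch.randn(2, 3, 8, 8)
b = torch.zeros(3)
y = x + b.reshape(1, -1, 1, 1)              # broadcast bias along dim=1
y = torch.nn.functional.leaky_relu(y, 0.2)  # def_alpha for "lrelu"
y = y * np.sqrt(2)                          # def_gain for "lrelu"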
@@ -223,7 +265,14 @@ def _parse_padding(padding):
     return padx0, padx1, pady0, pady1
 
 
-def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=False, gain=1, separable=None):
+def setup_filter(
+    f,
+    device=torch.device("cpu"),
+    normalize=True,
+    flip_filter=False,
+    gain=1,
+    separable=None,
+):
     r"""Convenience function to setup 2D FIR filter for `upfirdn2d()`.
 
     Args:
@@ -255,7 +304,7 @@ def setup_filter(f, device=torch.device('cpu'), normalize=True, flip_filter=Fals
 
     # Separable?
     if separable is None:
-        separable = (f.ndim == 1 and f.numel() >= 8)
+        separable = f.ndim == 1 and f.numel() >= 8
     if f.ndim == 1 and not separable:
         f = f.ger(f)
     assert f.ndim == (1 if separable else 2)
@@ -282,27 +331,82 @@ def _ntuple(n):
 to_2tuple = _ntuple(2)
 
 activation_funcs = {
-    'linear': EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
-    'relu': EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2,
-                     ref='y', has_2nd_grad=False),
-    'lrelu': EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2,
-                      def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
-    'tanh': EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y',
-                     has_2nd_grad=True),
-    'sigmoid': EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y',
-                        has_2nd_grad=True),
-    'elu': EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y',
-                    has_2nd_grad=True),
-    'selu': EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y',
-                     has_2nd_grad=True),
-    'softplus': EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8,
-                         ref='y', has_2nd_grad=True),
-    'swish': EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x',
-                      has_2nd_grad=True),
+    "linear": EasyDict(
+        func=lambda x, **_: x,
+        def_alpha=0,
+        def_gain=1,
+        cuda_idx=1,
+        ref="",
+        has_2nd_grad=False,
+    ),
+    "relu": EasyDict(
+        func=lambda x, **_: torch.nn.functional.relu(x),
+        def_alpha=0,
+        def_gain=np.sqrt(2),
+        cuda_idx=2,
+        ref="y",
+        has_2nd_grad=False,
+    ),
+    "lrelu": EasyDict(
+        func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha),
+        def_alpha=0.2,
+        def_gain=np.sqrt(2),
+        cuda_idx=3,
+        ref="y",
+        has_2nd_grad=False,
+    ),
+    "tanh": EasyDict(
+        func=lambda x, **_: torch.tanh(x),
+        def_alpha=0,
+        def_gain=1,
+        cuda_idx=4,
+        ref="y",
+        has_2nd_grad=True,
+    ),
+    "sigmoid": EasyDict(
+        func=lambda x, **_: torch.sigmoid(x),
+        def_alpha=0,
+        def_gain=1,
+        cuda_idx=5,
+        ref="y",
+        has_2nd_grad=True,
+    ),
+    "elu": EasyDict(
+        func=lambda x, **_: torch.nn.functional.elu(x),
+        def_alpha=0,
+        def_gain=1,
+        cuda_idx=6,
+        ref="y",
+        has_2nd_grad=True,
+    ),
+    "selu": EasyDict(
+        func=lambda x, **_: torch.nn.functional.selu(x),
+        def_alpha=0,
+        def_gain=1,
+        cuda_idx=7,
+        ref="y",
+        has_2nd_grad=True,
+    ),
+    "softplus": EasyDict(
+        func=lambda x, **_: torch.nn.functional.softplus(x),
+        def_alpha=0,
+        def_gain=1,
+        cuda_idx=8,
+        ref="y",
+        has_2nd_grad=True,
+    ),
+    "swish": EasyDict(
+        func=lambda x, **_: torch.sigmoid(x) * x,
+        def_alpha=0,
+        def_gain=np.sqrt(2),
+        cuda_idx=9,
+        ref="x",
+        has_2nd_grad=True,
+    ),
 }
 
 
-def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cuda'):
+def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl="cuda"):
     r"""Pad, upsample, filter, and downsample a batch of 2D images.
 
     Performs the following sequence of operations for each channel:
@@ -344,12 +448,13 @@ def upfirdn2d(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1, impl='cu
     """
     # assert isinstance(x, torch.Tensor)
    # assert impl in ['ref', 'cuda']
-    return _upfirdn2d_ref(x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain)
+    return _upfirdn2d_ref(
+        x, f, up=up, down=down, padding=padding, flip_filter=flip_filter, gain=gain
+    )
 
 
 def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
-    """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops.
-    """
+    """Slow reference implementation of `upfirdn2d()` using standard PyTorch ops."""
     # Validate arguments.
     assert isinstance(x, torch.Tensor) and x.ndim == 4
     if f is None:
@@ -372,8 +477,15 @@ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
     x = x.reshape([batch_size, num_channels, in_height * upy, in_width * upx])
 
     # Pad or crop.
-    x = torch.nn.functional.pad(x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)])
-    x = x[:, :, max(-pady0, 0): x.shape[2] - max(-pady1, 0), max(-padx0, 0): x.shape[3] - max(-padx1, 0)]
+    x = torch.nn.functional.pad(
+        x, [max(padx0, 0), max(padx1, 0), max(pady0, 0), max(pady1, 0)]
+    )
+    x = x[
+        :,
+        :,
+        max(-pady0, 0) : x.shape[2] - max(-pady1, 0),
+        max(-padx0, 0) : x.shape[3] - max(-padx1, 0),
+    ]
 
     # Setup filter.
     f = f * (gain ** (f.ndim / 2))
@@ -394,7 +506,7 @@ def _upfirdn2d_ref(x, f, up=1, down=1, padding=0, flip_filter=False, gain=1):
     return x
 
 
-def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
+def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl="cuda"):
     r"""Downsample a batch of 2D images using the given 2D FIR filter.
 
     By default, the result is padded so that its shape is a fraction of the input.
@@ -431,10 +543,12 @@ def downsample2d(x, f, down=2, padding=0, flip_filter=False, gain=1, impl='cuda'
         pady0 + (fh - downy + 1) // 2,
         pady1 + (fh - downy) // 2,
     ]
-    return upfirdn2d(x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl)
+    return upfirdn2d(
+        x, f, down=down, padding=p, flip_filter=flip_filter, gain=gain, impl=impl
+    )
 
 
-def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
+def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl="cuda"):
     r"""Upsample a batch of 2D images using the given 2D FIR filter.
 
     By default, the result is padded so that its shape is a multiple of the input.
@@ -471,7 +585,15 @@ def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1, impl='cuda'):
         pady0 + (fh + upy - 1) // 2,
         pady1 + (fh - upy) // 2,
     ]
-    return upfirdn2d(x, f, up=up, padding=p, flip_filter=flip_filter, gain=gain * upx * upy, impl=impl)
+    return upfirdn2d(
+        x,
+        f,
+        up=up,
+        padding=p,
+        flip_filter=flip_filter,
+        gain=gain * upx * upy,
+        impl=impl,
+    )
 
 
 class MinibatchStdLayer(torch.nn.Module):
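The upfirdn2d family above implements resampling as pad, zero-insertion upsample, FIR filter, then downsample. The sketch below is not the library routine, just the idea for the 2x upsample case in plain PyTorch: conv_transpose2d with stride 2 performs the zero-insert-and-filter step, the normalized [1, 3, 3, 1] outer-product kernel mirrors Conv2dLayer's default resample_filter, and the * 4.0 is the gain = up**2 compensation visible in upsample2d (padding handling is simplified here):

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 8, 8)
f = torch.tensor([1.0, 3.0, 3.0, 1.0])
k = torch.outer(f, f)
k = (k / k.sum())[None, None]  # [1, 1, 4, 4], normalized like setup_filter does
up = F.conv_transpose2d(x, k * 4.0, stride=2, padding=1)
print(up.shape)  # torch.Size([1, 1, 16, 16])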
@@ -482,13 +604,17 @@ class MinibatchStdLayer(torch.nn.Module):
 
     def forward(self, x):
         N, C, H, W = x.shape
-        G = torch.min(torch.as_tensor(self.group_size),
-                      torch.as_tensor(N)) if self.group_size is not None else N
+        G = (
+            torch.min(torch.as_tensor(self.group_size), torch.as_tensor(N))
+            if self.group_size is not None
+            else N
+        )
         F = self.num_channels
         c = C // F
 
-        y = x.reshape(G, -1, F, c, H,
-                      W)  # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
+        y = x.reshape(
+            G, -1, F, c, H, W
+        )  # [GnFcHW] Split minibatch N into n groups of size G, and channels C into F groups of size c.
         y = y - y.mean(dim=0)  # [GnFcHW] Subtract mean over group.
         y = y.square().mean(dim=0)  # [nFcHW] Calc variance over group.
         y = (y + 1e-8).sqrt()  # [nFcHW] Calc stddev over group.
@@ -500,17 +626,24 @@ class MinibatchStdLayer(torch.nn.Module):
 
 
 class FullyConnectedLayer(torch.nn.Module):
-    def __init__(self,
-                 in_features,  # Number of input features.
-                 out_features,  # Number of output features.
-                 bias=True,  # Apply additive bias before the activation function?
-                 activation='linear',  # Activation function: 'relu', 'lrelu', etc.
-                 lr_multiplier=1,  # Learning rate multiplier.
-                 bias_init=0,  # Initial value for the additive bias.
-                 ):
+    def __init__(
+        self,
+        in_features,  # Number of input features.
+        out_features,  # Number of output features.
+        bias=True,  # Apply additive bias before the activation function?
+        activation="linear",  # Activation function: 'relu', 'lrelu', etc.
+        lr_multiplier=1,  # Learning rate multiplier.
+        bias_init=0,  # Initial value for the additive bias.
+    ):
         super().__init__()
-        self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) / lr_multiplier)
-        self.bias = torch.nn.Parameter(torch.full([out_features], np.float32(bias_init))) if bias else None
+        self.weight = torch.nn.Parameter(
+            torch.randn([out_features, in_features]) / lr_multiplier
+        )
+        self.bias = (
+            torch.nn.Parameter(torch.full([out_features], np.float32(bias_init)))
+            if bias
+            else None
+        )
         self.activation = activation
 
         self.weight_gain = lr_multiplier / np.sqrt(in_features)
@@ -522,7 +655,7 @@ class FullyConnectedLayer(torch.nn.Module):
         if b is not None and self.bias_gain != 1:
             b = b * self.bias_gain
 
-        if self.activation == 'linear' and b is not None:
+        if self.activation == "linear" and b is not None:
             # out = torch.addmm(b.unsqueeze(0), x, w.t())
             x = x.matmul(w.t())
             out = x + b.reshape([-1 if i == x.ndim - 1 else 1 for i in range(x.ndim)])
@@ -532,22 +665,33 @@ class FullyConnectedLayer(torch.nn.Module):
         return out
 
 
-def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True):
-    """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations.
-    """
+def _conv2d_wrapper(
+    x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True
+):
+    """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations."""
     out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w)
 
     # Flip weight if requested.
-    if not flip_weight:  # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
+    if (
+        not flip_weight
+    ):  # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False).
         w = w.flip([2, 3])
 
     # Workaround performance pitfall in cuDNN 8.0.5, triggered when using
     # 1x1 kernel + memory_format=channels_last + less than 64 channels.
-    if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose:
+    if (
+        kw == 1
+        and kh == 1
+        and stride == 1
+        and padding in [0, [0, 0], (0, 0)]
+        and not transpose
+    ):
         if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64:
             if out_channels <= 4 and groups == 1:
                 in_shape = x.shape
-                x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1])
+                x = w.squeeze(3).squeeze(2) @ x.reshape(
+                    [in_shape[0], in_channels_per_group, -1]
+                )
                 x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]])
             else:
                 x = x.to(memory_format=torch.contiguous_format)
@@ -560,7 +704,9 @@ def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_w
     return op(x, w, stride=stride, padding=padding, groups=groups)
 
 
-def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False):
+def conv2d_resample(
+    x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False
+):
     r"""2D convolution with optional up/downsampling.
 
     Padding is performed only once at the beginning, not between the operations.
@@ -587,7 +733,9 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight
     # Validate arguments.
     assert isinstance(x, torch.Tensor) and (x.ndim == 4)
     assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype)
-    assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32)
+    assert f is None or (
+        isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32
+    )
     assert isinstance(up, int) and (up >= 1)
     assert isinstance(down, int) and (down >= 1)
     # assert isinstance(groups, int) and (groups >= 1), f"!!!!!! groups: {groups} isinstance(groups, int) {isinstance(groups, int)} {type(groups)}"
@@ -610,20 +758,31 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight
 
     # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve.
     if kw == 1 and kh == 1 and (down > 1 and up == 1):
-        x = upfirdn2d(x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
+        x = upfirdn2d(
+            x=x, f=f, down=down, padding=[px0, px1, py0, py1], flip_filter=flip_filter
+        )
         x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
         return x
 
     # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample.
     if kw == 1 and kh == 1 and (up > 1 and down == 1):
         x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
-        x = upfirdn2d(x=x, f=f, up=up, padding=[px0, px1, py0, py1], gain=up ** 2, flip_filter=flip_filter)
+        x = upfirdn2d(
+            x=x,
+            f=f,
+            up=up,
+            padding=[px0, px1, py0, py1],
+            gain=up**2,
+            flip_filter=flip_filter,
+        )
         return x
 
     # Fast path: downsampling only => use strided convolution.
     if down > 1 and up == 1:
         x = upfirdn2d(x=x, f=f, padding=[px0, px1, py0, py1], flip_filter=flip_filter)
-        x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight)
+        x = _conv2d_wrapper(
+            x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight
+        )
         return x
 
     # Fast path: upsampling with optional downsampling => use transpose strided convolution.
@@ -633,17 +792,31 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight
         else:
             w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw)
             w = w.transpose(1, 2)
-            w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw)
+            w = w.reshape(
+                groups * in_channels_per_group, out_channels // groups, kh, kw
+            )
         px0 -= kw - 1
         px1 -= kw - up
         py0 -= kh - 1
         py1 -= kh - up
         pxt = max(min(-px0, -px1), 0)
         pyt = max(min(-py0, -py1), 0)
-        x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt, pxt], groups=groups, transpose=True,
-                            flip_weight=(not flip_weight))
-        x = upfirdn2d(x=x, f=f, padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt], gain=up ** 2,
-                      flip_filter=flip_filter)
+        x = _conv2d_wrapper(
+            x=x,
+            w=w,
+            stride=up,
+            padding=[pyt, pxt],
+            groups=groups,
+            transpose=True,
+            flip_weight=(not flip_weight),
+        )
+        x = upfirdn2d(
+            x=x,
+            f=f,
+            padding=[px0 + pxt, px1 + pxt, py0 + pyt, py1 + pyt],
+            gain=up**2,
+            flip_filter=flip_filter,
+        )
         if down > 1:
             x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
         return x
@@ -651,11 +824,19 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight
     # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d.
     if up == 1 and down == 1:
         if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0:
-            return _conv2d_wrapper(x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight)
+            return _conv2d_wrapper(
+                x=x, w=w, padding=[py0, px0], groups=groups, flip_weight=flip_weight
+            )
 
     # Fallback: Generic reference implementation.
-    x = upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0, px1, py0, py1], gain=up ** 2,
-                  flip_filter=flip_filter)
+    x = upfirdn2d(
+        x=x,
+        f=(f if up > 1 else None),
+        up=up,
+        padding=[px0, px1, py0, py1],
+        gain=up**2,
+        flip_filter=flip_filter,
+    )
     x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight)
     if down > 1:
         x = upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter)
@@ -663,50 +844,68 @@ def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight
 
 
 class Conv2dLayer(torch.nn.Module):
-    def __init__(self,
-                 in_channels,  # Number of input channels.
-                 out_channels,  # Number of output channels.
-                 kernel_size,  # Width and height of the convolution kernel.
-                 bias=True,  # Apply additive bias before the activation function?
-                 activation='linear',  # Activation function: 'relu', 'lrelu', etc.
-                 up=1,  # Integer upsampling factor.
-                 down=1,  # Integer downsampling factor.
-                 resample_filter=[1, 3, 3, 1],  # Low-pass filter to apply when resampling activations.
-                 conv_clamp=None,  # Clamp the output to +-X, None = disable clamping.
-                 channels_last=False,  # Expect the input to have memory_format=channels_last?
-                 trainable=True,  # Update the weights of this layer during training?
-                 ):
+    def __init__(
+        self,
+        in_channels,  # Number of input channels.
+        out_channels,  # Number of output channels.
+        kernel_size,  # Width and height of the convolution kernel.
+        bias=True,  # Apply additive bias before the activation function?
+        activation="linear",  # Activation function: 'relu', 'lrelu', etc.
+        up=1,  # Integer upsampling factor.
+        down=1,  # Integer downsampling factor.
+        resample_filter=[
+            1,
+            3,
+            3,
+            1,
+        ],  # Low-pass filter to apply when resampling activations.
+        conv_clamp=None,  # Clamp the output to +-X, None = disable clamping.
+        channels_last=False,  # Expect the input to have memory_format=channels_last?
+        trainable=True,  # Update the weights of this layer during training?
+    ):
         super().__init__()
         self.activation = activation
         self.up = up
         self.down = down
-        self.register_buffer('resample_filter', setup_filter(resample_filter))
+        self.register_buffer("resample_filter", setup_filter(resample_filter))
         self.conv_clamp = conv_clamp
         self.padding = kernel_size // 2
-        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size ** 2))
+        self.weight_gain = 1 / np.sqrt(in_channels * (kernel_size**2))
         self.act_gain = activation_funcs[activation].def_gain
 
-        memory_format = torch.channels_last if channels_last else torch.contiguous_format
-        weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(memory_format=memory_format)
+        memory_format = (
+            torch.channels_last if channels_last else torch.contiguous_format
+        )
+        weight = torch.randn([out_channels, in_channels, kernel_size, kernel_size]).to(
+            memory_format=memory_format
+        )
         bias = torch.zeros([out_channels]) if bias else None
         if trainable:
             self.weight = torch.nn.Parameter(weight)
             self.bias = torch.nn.Parameter(bias) if bias is not None else None
         else:
-            self.register_buffer('weight', weight)
+            self.register_buffer("weight", weight)
             if bias is not None:
-                self.register_buffer('bias', bias)
+                self.register_buffer("bias", bias)
             else:
                 self.bias = None
 
     def forward(self, x, gain=1):
         w = self.weight * self.weight_gain
-        x = conv2d_resample(x=x, w=w, f=self.resample_filter, up=self.up, down=self.down,
-                            padding=self.padding)
+        x = conv2d_resample(
+            x=x,
+            w=w,
+            f=self.resample_filter,
+            up=self.up,
+            down=self.down,
+            padding=self.padding,
+        )
 
         act_gain = self.act_gain * gain
         act_clamp = self.conv_clamp * gain if self.conv_clamp is not None else None
-        out = bias_act(x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp)
+        out = bias_act(
+            x, self.bias, act=self.activation, gain=act_gain, clamp=act_clamp
+        )
         return out
 
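Conv2dLayer.forward composes the two primitives reflowed above: conv2d_resample for the optionally resampling convolution and bias_act for the fused bias/activation/clamp. For the plain stride-1, act="lrelu" case the equivalent vanilla-PyTorch path is roughly the following sketch (bias omitted for brevity):

import numpy as np
import torch

x = torch.randn(1, 16, 32, 32)
w = torch.randn(32, 16, 3, 3)
weight_gain = 1 / np.sqrt(16 * 3**2)  # as computed in __init__
y = torch.nn.functional.conv2d(x, w * weight_gain, padding=1)
y = torch.nn.functional.leaky_relu(y, 0.2) * np.sqrt(2)  # bias_act with act="lrelu"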
@@ -721,3 +920,22 @@ def set_seed(seed: int):
     np.random.seed(seed)
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
+
+
+def get_scheduler(sd_sampler, scheduler_config):
+    if sd_sampler == SDSampler.ddim:
+        return DDIMScheduler.from_config(scheduler_config)
+    elif sd_sampler == SDSampler.pndm:
+        return PNDMScheduler.from_config(scheduler_config)
+    elif sd_sampler == SDSampler.k_lms:
+        return LMSDiscreteScheduler.from_config(scheduler_config)
+    elif sd_sampler == SDSampler.k_euler:
+        return EulerDiscreteScheduler.from_config(scheduler_config)
+    elif sd_sampler == SDSampler.k_euler_a:
+        return EulerAncestralDiscreteScheduler.from_config(scheduler_config)
+    elif sd_sampler == SDSampler.dpm_plus_plus:
+        return DPMSolverMultistepScheduler.from_config(scheduler_config)
+    elif sd_sampler == SDSampler.uni_pc:
+        return UniPCMultistepScheduler.from_config(scheduler_config)
+    else:
+        raise ValueError(sd_sampler)
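get_scheduler is new in this commit: it maps the SDSampler enum (see the schema hunk below) onto diffusers scheduler classes, and the uni_pc branch lines up with the requirements.txt bump to diffusers 0.14.0, where UniPCMultistepScheduler is available. A hedged usage sketch; the config dict here is illustrative, real callers would pass an existing pipeline's scheduler config:

from diffusers import UniPCMultistepScheduler

scheduler = UniPCMultistepScheduler.from_config(
    {
        "num_train_timesteps": 1000,  # assumed values for the demo
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
    }
)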
@@ -1,7 +1,9 @@
 import torch
 import gc
 
+from lama_cleaner.const import SD15_MODELS
 from lama_cleaner.helper import switch_mps_device
+from lama_cleaner.model.controlnet import ControlNet
 from lama_cleaner.model.fcf import FcF
 from lama_cleaner.model.lama import LaMa
 from lama_cleaner.model.ldm import LDM
@@ -20,7 +22,7 @@ models = {
     "zits": ZITS,
     "mat": MAT,
     "fcf": FcF,
-    "sd1.5": SD15,
+    SD15.name: SD15,
     Anything4.name: Anything4,
     RealisticVision14.name: RealisticVision14,
     "cv2": OpenCV2,
@@ -39,6 +41,9 @@ class ModelManager:
         self.model = self.init_model(name, device, **kwargs)
 
     def init_model(self, name: str, device, **kwargs):
+        if name in SD15_MODELS and kwargs.get("sd_controlnet", False):
+            return ControlNet(device, **{**kwargs, "name": name})
+
         if name in models:
             model = models[name](device, **kwargs)
         else:
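The only behavioral change to ModelManager is the early return added to init_model: when the requested checkpoint belongs to the SD 1.5 family and sd_controlnet is set, the ControlNet wrapper is constructed instead of the registry entry. A stand-alone sketch of that rule (the SD15_MODELS members shown are assumed for the demo; the real list lives in lama_cleaner.const):

SD15_MODELS = ["sd1.5", "anything4", "realisticVision1.4"]  # assumed members

def pick(name, **kwargs):
    if name in SD15_MODELS and kwargs.get("sd_controlnet", False):
        return "ControlNet"  # i.e. ControlNet(device, **{**kwargs, "name": name})
    return "registry"        # i.e. models[name](device, **kwargs)

assert pick("sd1.5", sd_controlnet=True) == "ControlNet"
assert pick("lama", sd_controlnet=True) == "registry"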
@@ -5,25 +5,7 @@ from pathlib import Path
 
 from loguru import logger
 
-from lama_cleaner.const import (
-    AVAILABLE_MODELS,
-    NO_HALF_HELP,
-    CPU_OFFLOAD_HELP,
-    DISABLE_NSFW_HELP,
-    SD_CPU_TEXTENCODER_HELP,
-    LOCAL_FILES_ONLY_HELP,
-    AVAILABLE_DEVICES,
-    ENABLE_XFORMERS_HELP,
-    MODEL_DIR_HELP,
-    OUTPUT_DIR_HELP,
-    INPUT_HELP,
-    GUI_HELP,
-    DEFAULT_DEVICE,
-    NO_GUI_AUTO_CLOSE_HELP,
-    DEFAULT_MODEL_DIR,
-    DEFAULT_MODEL,
-    MPS_SUPPORT_MODELS,
-)
+from lama_cleaner.const import *
 from lama_cleaner.runtime import dump_environment_info
 
 
@@ -55,6 +37,9 @@ def parse_args():
     parser.add_argument(
         "--sd-cpu-textencoder", action="store_true", help=SD_CPU_TEXTENCODER_HELP
     )
+    parser.add_argument(
+        "--sd-controlnet", action="store_true", help=SD_CONTROLNET_HELP
+    )
     parser.add_argument(
         "--local-files-only", action="store_true", help=LOCAL_FILES_ONLY_HELP
     )
@@ -133,6 +118,10 @@ def parse_args():
             "torch.cuda.is_available() is False, please use --device cpu or check your pytorch installation"
         )
 
+    if args.sd_controlnet:
+        if args.model not in SD15_MODELS:
+            logger.warning(f"--sd_controlnet only support {SD15_MODELS}")
+
     if args.model_dir and args.model_dir is not None:
         if os.path.isfile(args.model_dir):
             parser.error(f"invalid --model-dir: {args.model_dir} is a file")
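The flag itself is ordinary argparse store_true wiring, and the check above only warns: passing --sd-controlnet with a non-SD-1.5 model is not rejected. A minimal stand-alone reproduction (the help string below is a stand-in for the SD_CONTROLNET_HELP constant, which is not shown in this diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--sd-controlnet",
    action="store_true",
    help="use ControlNet with SD 1.5 models",  # stand-in for SD_CONTROLNET_HELP
)
args = parser.parse_args(["--sd-controlnet"])
assert args.sd_controlnet is True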
@@ -24,9 +24,10 @@ class SDSampler(str, Enum):
     ddim = "ddim"
     pndm = "pndm"
     k_lms = "k_lms"
-    k_euler = 'k_euler'
-    k_euler_a = 'k_euler_a'
-    dpm_plus_plus = 'dpm++'
+    k_euler = "k_euler"
+    k_euler_a = "k_euler_a"
+    dpm_plus_plus = "dpm++"
+    uni_pc = "uni_pc"
 
 
 class Config(BaseModel):
@@ -71,14 +72,14 @@ class Config(BaseModel):
     # Higher guidance scale encourages to generate images that are closely linked
     # to the text prompt, usually at the expense of lower image quality.
     sd_guidance_scale: float = 7.5
-    sd_sampler: str = SDSampler.ddim
+    sd_sampler: str = SDSampler.uni_pc
     # -1 mean random seed
     sd_seed: int = 42
     sd_match_histograms: bool = False
 
     # Configs for opencv inpainting
     # opencv document https://docs.opencv.org/4.6.0/d7/d8b/group__photo__inpaint.html#gga8002a65f5a3328fbf15df81b842d3c3ca05e763003a805e6c11c673a9f4ba7d07
-    cv2_flag: str = 'INPAINT_NS'
+    cv2_flag: str = "INPAINT_NS"
     cv2_radius: int = 4
 
     # Paint by Example
@@ -93,3 +94,6 @@ class Config(BaseModel):
     p2p_steps: int = 50
     p2p_image_guidance_scale: float = 1.5
     p2p_guidance_scale: float = 7.5
+
+    # ControlNet
+    controlnet_conditioning_scale: float = 0.4
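Two schema changes ride together here: uni_pc joins SDSampler and replaces ddim as the default sd_sampler, and controlnet_conditioning_scale is added (in diffusers-style ControlNet pipelines this scales the control branch's residuals: roughly, 0.0 ignores the control image and 1.0 applies it at full strength). Because SDSampler subclasses str, the raw strings posted by the frontend coerce and compare directly; a stand-alone sketch:

from enum import Enum

class SDSampler(str, Enum):  # mirrors the enum above, trimmed for the demo
    dpm_plus_plus = "dpm++"
    uni_pc = "uni_pc"

assert SDSampler("uni_pc") is SDSampler.uni_pc  # form strings round-trip
assert SDSampler.uni_pc == "uni_pc"             # str subclass compares to plain str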
@@ -18,6 +18,7 @@ import numpy as np
 from loguru import logger
 from watchdog.events import FileSystemEventHandler
 
+from lama_cleaner.const import SD15_MODELS
 from lama_cleaner.interactive_seg import InteractiveSeg, Click
 from lama_cleaner.make_gif import make_compare_gif
 from lama_cleaner.model_manager import ModelManager
@@ -88,6 +89,7 @@ interactive_seg_model: InteractiveSeg = None
 device = None
 input_image_path: str = None
 is_disable_model_switch: bool = False
+is_controlnet: bool = False
 is_enable_file_manager: bool = False
 is_enable_auto_saving: bool = False
 is_desktop: bool = False
@@ -257,6 +259,7 @@ def process():
         p2p_steps=form["p2pSteps"],
         p2p_image_guidance_scale=form["p2pImageGuidanceScale"],
         p2p_guidance_scale=form["p2pGuidanceScale"],
+        controlnet_conditioning_scale=form["controlnet_conditioning_scale"],
     )
 
     if config.sd_seed == -1:
@@ -347,6 +350,12 @@ def current_model():
     return model.name, 200
 
 
+@app.route("/is_controlnet")
+def get_is_controlnet():
+    res = "true" if is_controlnet else "false"
+    return res, 200
+
+
 @app.route("/is_disable_model_switch")
 def get_is_disable_model_switch():
     res = "true" if is_disable_model_switch else "false"
@@ -427,6 +436,9 @@ def main(args):
     global thumb
     global output_dir
     global is_enable_auto_saving
+    global is_controlnet
+    if args.sd_controlnet and args.model in SD15_MODELS:
+        is_controlnet = True
 
     output_dir = args.output_dir
     if output_dir is not None:
@@ -464,6 +476,7 @@ def main(args):
 
     model = ModelManager(
         name=args.model,
+        sd_controlnet=args.sd_controlnet,
         device=device,
         no_half=args.no_half,
         hf_access_token=args.hf_access_token,
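The new /is_controlnet route mirrors the neighboring feature-flag endpoints: it returns the literal strings "true" or "false" so the web UI can decide whether to show ControlNet controls. A quick check against a locally running server (host and port are the defaults visible in the web_config hunk below):

import requests

resp = requests.get("http://127.0.0.1:8080/is_controlnet")
print(resp.text)  # "true" when launched with --sd-controlnet and an SD 1.5 model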
@@ -6,25 +6,7 @@ import gradio as gr
 from loguru import logger
 from pydantic import BaseModel
 
-from lama_cleaner.const import (
-    AVAILABLE_MODELS,
-    AVAILABLE_DEVICES,
-    CPU_OFFLOAD_HELP,
-    NO_HALF_HELP,
-    DISABLE_NSFW_HELP,
-    SD_CPU_TEXTENCODER_HELP,
-    LOCAL_FILES_ONLY_HELP,
-    ENABLE_XFORMERS_HELP,
-    MODEL_DIR_HELP,
-    OUTPUT_DIR_HELP,
-    INPUT_HELP,
-    GUI_HELP,
-    DEFAULT_MODEL,
-    DEFAULT_DEVICE,
-    NO_GUI_AUTO_CLOSE_HELP,
-    DEFAULT_MODEL_DIR,
-    MPS_SUPPORT_MODELS,
-)
+from lama_cleaner.const import *
 
 _config_file = None
 
@@ -33,6 +15,7 @@ class Config(BaseModel):
     host: str = "127.0.0.1"
     port: int = 8080
     model: str = DEFAULT_MODEL
+    sd_controlnet: bool = False
     device: str = DEFAULT_DEVICE
     gui: bool = False
     no_gui_auto_close: bool = False
@@ -59,6 +42,7 @@ def save_config(
     host,
     port,
     model,
+    sd_controlnet,
     device,
     gui,
     no_gui_auto_close,
@@ -127,6 +111,9 @@ def main(config_file: str):
             disable_nsfw = gr.Checkbox(
                 init_config.disable_nsfw, label=f"{DISABLE_NSFW_HELP}"
             )
+            sd_controlnet = gr.Checkbox(
+                init_config.sd_controlnet, label=f"{SD_CONTROLNET_HELP}"
+            )
             sd_cpu_textencoder = gr.Checkbox(
                 init_config.sd_cpu_textencoder, label=f"{SD_CPU_TEXTENCODER_HELP}"
             )
@@ -149,6 +136,7 @@ def main(config_file: str):
                 host,
                 port,
                 model,
+                sd_controlnet,
                 device,
                 gui,
                 no_gui_auto_close,
@@ -12,7 +12,7 @@ pytest
 yacs
 markupsafe==2.0.1
 scikit-image==0.19.3
-diffusers[torch]==0.12.1
+diffusers[torch]==0.14.0
 transformers>=4.25.1
 watchdog==2.2.1
 gradio