ldm add plms sampler

Author: Qing
Date:   2022-06-12 13:14:17 +08:00
Commit: 35b92ba9de
Parent: 55197f2209
11 changed files with 478 additions and 207 deletions

View File

@@ -16,6 +16,7 @@ export default async function inpaint(
   fd.append('mask', mask)
   fd.append('ldmSteps', settings.ldmSteps.toString())
+  fd.append('ldmSampler', settings.ldmSampler.toString())
   fd.append('hdStrategy', settings.hdStrategy)
   fd.append('hdStrategyCropMargin', settings.hdStrategyCropMargin.toString())
   fd.append(

View File: components/Settings/HDSettingBlock.tsx

@@ -11,6 +11,11 @@ export enum HDStrategy {
   CROP = 'Crop',
 }
 
+export enum LDMSampler {
+  ddim = 'ddim',
+  plms = 'plms',
+}
+
 function HDSettingBlock() {
   const [setting, setSettingState] = useRecoilState(settingState)

View File: components/Settings/ModelSettingBlock.tsx

@@ -2,6 +2,7 @@ import React, { ReactNode } from 'react'
 import { useRecoilState } from 'recoil'
 import { settingState } from '../../store/Atoms'
 import Selector from '../shared/Selector'
+import { LDMSampler } from './HDSettingBlock'
 import NumberInputSetting from './NumberInputSetting'
 import SettingBlock from './SettingBlock'

@@ -19,6 +20,12 @@ function ModelSettingBlock() {
     })
   }

+  const onLDMSamplerChange = (value: LDMSampler) => {
+    setSettingState(old => {
+      return { ...old, ldmSampler: value }
+    })
+  }
+
   const renderModelDesc = (
     name: string,
     paperUrl: string,

@@ -65,6 +72,19 @@ function ModelSettingBlock() {
           })
         }}
       />
+      <SettingBlock
+        className="sub-setting-block"
+        title="Sampler"
+        input={
+          <Selector
+            width={80}
+            value={setting.ldmSampler as string}
+            options={Object.values(LDMSampler)}
+            onChange={val => onLDMSamplerChange(val as LDMSampler)}
+          />
+        }
+      />
     </div>
   )
 }

View File

@@ -56,7 +56,7 @@ export default function ShortcutsModal() {
       />
       <ShortCut content="Undo Inpainting" keys={[CmdOrCtrl, 'Z']} />
       <ShortCut content="Pan" keys={['Space & Drag']} />
-      <ShortCut content="View Original Image" keys={['Hold Tag']} />
+      <ShortCut content="View Original Image" keys={['Hold Tab']} />
       <ShortCut content="Reset Zoom/Pan" keys={['Esc']} />
       <ShortCut content="Cancel Mask Drawing" keys={['Esc']} />
       <ShortCut content="Run Inpainting Manually" keys={['Shift', 'R']} />

View File

@@ -2,8 +2,9 @@
   all: unset;
   flex: 1 0 auto;
   border-radius: 0.5rem;
-  padding: 0.4rem 0.8rem;
+  padding: 0 0.8rem;
   outline: 1px solid var(--border-color);
+  height: 36px;
 
   &:focus-visible {
     outline: 1px solid var(--yellow-accent);

View File: store/Atoms.ts

@@ -1,5 +1,5 @@
 import { atom } from 'recoil'
-import { HDStrategy } from '../components/Settings/HDSettingBlock'
+import { HDStrategy, LDMSampler } from '../components/Settings/HDSettingBlock'
 import { AIModel } from '../components/Settings/ModelSettingBlock'
 import { ToastState } from '../components/shared/Toast'

@@ -43,6 +43,7 @@ export interface Settings {
   // For LDM
   ldmSteps: number
+  ldmSampler: LDMSampler
 }
 
 export const settingStateDefault = {

@@ -50,6 +51,7 @@ export const settingStateDefault = {
   runInpaintingManually: false,
   model: AIModel.LAMA,
   ldmSteps: 50,
+  ldmSampler: LDMSampler.plms,
   hdStrategy: HDStrategy.RESIZE,
   hdStrategyResizeLimit: 2048,
   hdStrategyCropTrigerSize: 2048,

View File: lama_cleaner/model/ddim_sampler.py (new file)

@@ -0,0 +1,193 @@
import torch
import numpy as np
from tqdm import tqdm
from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
from loguru import logger


class DDIMSampler(object):
    def __init__(self, model, schedule="linear"):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        setattr(self, name, attr)

    def make_schedule(
        self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True
    ):
        self.ddim_timesteps = make_ddim_timesteps(
            ddim_discr_method=ddim_discretize,
            num_ddim_timesteps=ddim_num_steps,
            # array([1])
            num_ddpm_timesteps=self.ddpm_num_timesteps,
            verbose=verbose,
        )
        alphas_cumprod = self.model.alphas_cumprod  # torch.Size([1000])
        assert (
            alphas_cumprod.shape[0] == self.ddpm_num_timesteps
        ), "alphas have to be defined for each timestep"
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer("betas", to_torch(self.model.betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer(
            "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev)
        )

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer(
            "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod",
            to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())),
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu()))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod",
            to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)),
        )

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
            alphacums=alphas_cumprod.cpu(),
            ddim_timesteps=self.ddim_timesteps,
            eta=ddim_eta,
            verbose=verbose,
        )
        self.register_buffer("ddim_sigmas", ddim_sigmas)
        self.register_buffer("ddim_alphas", ddim_alphas)
        self.register_buffer("ddim_alphas_prev", ddim_alphas_prev)
        self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev)
            / (1 - self.alphas_cumprod)
            * (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
        )
        self.register_buffer(
            "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps
        )

    @torch.no_grad()
    def sample(self, steps, conditioning, batch_size, shape):
        self.make_schedule(ddim_num_steps=steps, ddim_eta=0, verbose=False)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        # samples: 1,3,128,128
        return self.ddim_sampling(
            conditioning,
            size,
            quantize_denoised=False,
            ddim_use_original_steps=False,
            noise_dropout=0,
            temperature=1.0,
        )

    @torch.no_grad()
    def ddim_sampling(
        self,
        cond,
        shape,
        ddim_use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
    ):
        device = self.model.betas.device
        b = shape[0]
        img = torch.randn(shape, device=device, dtype=cond.dtype)
        timesteps = (
            self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        )

        time_range = (
            reversed(range(0, timesteps))
            if ddim_use_original_steps
            else np.flip(timesteps)
        )
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        logger.info(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            outs = self.p_sample_ddim(
                img,
                cond,
                ts,
                index=index,
                use_original_steps=ddim_use_original_steps,
                quantize_denoised=quantize_denoised,
                temperature=temperature,
                noise_dropout=noise_dropout,
            )
            img, _ = outs

        return img

    @torch.no_grad()
    def p_sample_ddim(
        self,
        x,
        c,
        t,
        index,
        repeat_noise=False,
        use_original_steps=False,
        quantize_denoised=False,
        temperature=1.0,
        noise_dropout=0.0,
    ):
        b, *_, device = *x.shape, x.device
        e_t = self.model.apply_model(x, t, c)

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = (
            self.model.alphas_cumprod_prev
            if use_original_steps
            else self.ddim_alphas_prev
        )
        sqrt_one_minus_alphas = (
            self.model.sqrt_one_minus_alphas_cumprod
            if use_original_steps
            else self.ddim_sqrt_one_minus_alphas
        )
        sigmas = (
            self.model.ddim_sigmas_for_original_num_steps
            if use_original_steps
            else self.ddim_sigmas
        )
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full(
            (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device
        )

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:  # not used
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t
        dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.0:  # not used
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0
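
The update in p_sample_ddim is the standard DDIM step: x_prev = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t^2) * e_t + sigma_t * noise. Since sample() hard-codes ddim_eta=0, sigma_t is zero and the chain is fully deterministic. A scalar sanity sketch of that step (the alpha values are hypothetical stand-ins, not from the real schedule):

import torch

# Hypothetical scalars standing in for alphas[index] and alphas_prev[index].
a_t, a_prev, sigma_t = torch.tensor(0.90), torch.tensor(0.95), torch.tensor(0.0)
x, e_t = torch.tensor(1.0), torch.tensor(0.1)

pred_x0 = (x - (1 - a_t).sqrt() * e_t) / a_t.sqrt()       # current prediction for x_0
dir_xt = (1.0 - a_prev - sigma_t ** 2).sqrt() * e_t       # direction pointing to x_t
x_prev = a_prev.sqrt() * pred_x0 + dir_xt                 # noise term vanishes at eta = 0
print(round(float(pred_x0), 4), round(float(x_prev), 4))  # 1.0208 1.0173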

View File

@@ -5,17 +5,15 @@ import torch
 from loguru import logger
 
 from lama_cleaner.model.base import InpaintModel
-from lama_cleaner.schema import Config
+from lama_cleaner.model.ddim_sampler import DDIMSampler
+from lama_cleaner.model.plms_sampler import PLMSSampler
+from lama_cleaner.schema import Config, LDMSampler
 
 torch.manual_seed(42)
 import torch.nn as nn
-from tqdm import tqdm
 from lama_cleaner.helper import download_model, norm_img, get_cache_path_by_url
 from lama_cleaner.model.utils import (
     make_beta_schedule,
-    make_ddim_timesteps,
-    make_ddim_sampling_parameters,
-    noise_like,
     timestep_embedding,
 )
@@ -94,7 +92,7 @@ class DDPM(nn.Module):
         self.linear_start = linear_start
         self.linear_end = linear_end
         assert (
             alphas_cumprod.shape[0] == self.num_timesteps
         ), "alphas have to be defined for each timestep"
 
         to_torch = lambda x: torch.tensor(x, dtype=torch.float32).to(self.device)
@@ -120,7 +118,7 @@ class DDPM(nn.Module):
         # calculations for posterior q(x_{t-1} | x_t, x_0)
         posterior_variance = (1 - self.v_posterior) * betas * (
             1.0 - alphas_cumprod_prev
         ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
         # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
         self.register_buffer("posterior_variance", to_torch(posterior_variance))
@@ -142,16 +140,16 @@ class DDPM(nn.Module):
         if self.parameterization == "eps":
             lvlb_weights = self.betas ** 2 / (
                 2
                 * self.posterior_variance
                 * to_torch(alphas)
                 * (1 - self.alphas_cumprod)
             )
         elif self.parameterization == "x0":
             lvlb_weights = (
                 0.5
                 * np.sqrt(torch.Tensor(alphas_cumprod))
                 / (2.0 * 1 - torch.Tensor(alphas_cumprod))
             )
         else:
             raise NotImplementedError("mu not supported")
@@ -221,192 +219,6 @@ class LatentDiffusion(DDPM):
         return x_recon
 
-class DDIMSampler(object):
-    ... (remaining removed lines omitted: the deleted class body is line-for-line identical to the new lama_cleaner/model/ddim_sampler.py shown above)
 
 def load_jit_model(url, device):
     model_path = download_model(url)
     logger.info(f"Load LDM model from: {model_path}")
@@ -432,8 +244,7 @@ class LDM(InpaintModel):
             self.cond_stage_model_decode = self.cond_stage_model_decode.half()
             self.cond_stage_model_encode = self.cond_stage_model_encode.half()
 
-        model = LatentDiffusion(self.diffusion_model, device)
-        self.sampler = DDIMSampler(model)
+        self.model = LatentDiffusion(self.diffusion_model, device)
 
     @staticmethod
     def is_downloaded() -> bool:
@@ -454,6 +265,13 @@ class LDM(InpaintModel):
         # image [1,3,512,512] float32
         # mask: [1,1,512,512] float32
         # masked_image: [1,3,512,512] float32
+        if config.ldm_sampler == LDMSampler.ddim:
+            sampler = DDIMSampler(self.model)
+        elif config.ldm_sampler == LDMSampler.plms:
+            sampler = PLMSSampler(self.model)
+        else:
+            raise ValueError()
+
         steps = config.ldm_steps
         image = norm_img(image)
         mask = norm_img(mask)
@@ -465,7 +283,6 @@ class LDM(InpaintModel):
         mask = torch.from_numpy(mask).unsqueeze(0).to(self.device)
 
         masked_image = (1 - mask) * image
-        image = self._norm(image)
         mask = self._norm(mask)
         masked_image = self._norm(masked_image)
@@ -476,7 +293,7 @@ class LDM(InpaintModel):
         c = torch.cat((c, cc), dim=1)  # 1,4,128,128
 
         shape = (c.shape[1] - 1,) + c.shape[2:]
-        samples_ddim = self.sampler.sample(
+        samples_ddim = sampler.sample(
             steps=steps, conditioning=c, batch_size=c.shape[0], shape=shape
         )
         torch.cuda.empty_cache()
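
One detail worth noting: DDIMSampler.sample takes (steps, conditioning, batch_size, shape) positionally, while PLMSSampler.sample takes (steps, batch_size, shape, conditioning=None, ...), so the call above works for both only because every argument is passed by keyword. Building the sampler per request from config.ldm_sampler also lets the UI flip between ddim and plms without reloading model weights. A minimal call sketch, where `sampler` is whichever class the config selected and `c` is the 1,4,128,128 conditioning built in forward:

# Sketch: either sampler class accepts this keyword call.
shape = (c.shape[1] - 1,) + c.shape[2:]  # (3, 128, 128)
samples = sampler.sample(
    steps=config.ldm_steps,
    conditioning=c,
    batch_size=c.shape[0],
    shape=shape,
)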

View File: lama_cleaner/model/plms_sampler.py (new file)

@@ -0,0 +1,225 @@
# From: https://github.com/CompVis/latent-diffusion/blob/main/ldm/models/diffusion/plms.py
import torch
import numpy as np
from lama_cleaner.model.utils import make_ddim_timesteps, make_ddim_sampling_parameters, noise_like
from tqdm import tqdm


class PLMSSampler(object):
    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule

    def register_buffer(self, name, attr):
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        if ddim_eta != 0:
            raise ValueError('ddim_eta must be 0 for PLMS')
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               steps,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=False,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
               **kwargs
               ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=steps, ddim_eta=eta, verbose=verbose)

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for PLMS sampling is {size}')
        samples = self.plms_sampling(conditioning, size,
                                     callback=callback,
                                     img_callback=img_callback,
                                     quantize_denoised=quantize_x0,
                                     mask=mask, x0=x0,
                                     ddim_use_original_steps=False,
                                     noise_dropout=noise_dropout,
                                     temperature=temperature,
                                     score_corrector=score_corrector,
                                     corrector_kwargs=corrector_kwargs,
                                     x_T=x_T,
                                     log_every_t=log_every_t,
                                     unconditional_guidance_scale=unconditional_guidance_scale,
                                     unconditional_conditioning=unconditional_conditioning,
                                     )
        return samples

    @torch.no_grad()
    def plms_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, ):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running PLMS Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
        old_eps = []

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      old_eps=old_eps, t_next=ts_next)
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            if len(old_eps) >= 4:
                old_eps.pop(0)
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

        return img

    @torch.no_grad()
    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                e_t = self.model.apply_model(x, t, c)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

            if score_corrector is not None:
                assert self.model.parameterization == "eps"
                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

            return e_t

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0

        e_t = get_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t
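
The tail of p_sample_plms is the pseudo linear multistep itself: after a warm-up step that averages e_t with a one-step lookahead (the Improved Euler branch), it switches to classical Adams-Bashforth extrapolation over the last one to three eps predictions, which is why plms_sampling caps old_eps at three entries. A self-contained check of those coefficients, with NumPy scalars standing in for the latent tensors:

import numpy as np

def ab_eps_prime(e_t, old_eps):
    # Adams-Bashforth weights from p_sample_plms, selected by history length.
    if len(old_eps) == 1:
        return (3 * e_t - old_eps[-1]) / 2
    if len(old_eps) == 2:
        return (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
    return (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24

# Each weight set sums to 1, so a constant eps history is reproduced exactly.
for hist in ([1.0], [1.0, 1.0], [1.0, 1.0, 1.0]):
    assert np.isclose(ab_eps_prime(1.0, hist), 1.0)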

View File: lama_cleaner/schema.py

@@ -9,8 +9,14 @@ class HDStrategy(str, Enum):
     CROP = 'Crop'
 
 
+class LDMSampler(str, Enum):
+    ddim = 'ddim'
+    plms = 'plms'
+
+
 class Config(BaseModel):
     ldm_steps: int
+    ldm_sampler: str
     hd_strategy: str
     hd_strategy_crop_margin: int
     hd_strategy_crop_trigger_size: int

View File: lama_cleaner/server.py

@@ -93,6 +93,7 @@ def process():
     config = Config(
         ldm_steps=form["ldmSteps"],
+        ldm_sampler=form["ldmSampler"],
         hd_strategy=form["hdStrategy"],
         hd_strategy_crop_margin=form["hdStrategyCropMargin"],
         hd_strategy_crop_trigger_size=form["hdStrategyCropTrigerSize"],
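
These form keys are exactly the camelCase fields the frontend appends in inpaint() above, mapped onto Config's snake_case fields. A hedged end-to-end sketch using requests; the /inpaint route, the port, and the hdStrategyResizeLimit key are assumptions drawn from the surrounding code rather than from this diff:

import requests

resp = requests.post(
    "http://127.0.0.1:8080/inpaint",  # assumed route and port
    files={
        "image": open("image.png", "rb"),
        "mask": open("mask.png", "rb"),
    },
    data={
        "ldmSteps": "50",
        "ldmSampler": "plms",  # or "ddim"
        "hdStrategy": "Resize",
        "hdStrategyCropMargin": "128",
        "hdStrategyCropTrigerSize": "2048",  # the repo's own spelling of "Trigger"
        "hdStrategyResizeLimit": "2048",
    },
)
with open("result.png", "wb") as f:
    f.write(resp.content)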